"""
Test the query domain classification functionality.

This script tests the new LLM-based query domain classification functionality
to ensure it correctly classifies queries into academic, code, current_events,
and general categories.
"""
import asyncio
import json
import os
import sys
from typing import Any, Dict, List

# Add parent directory to path so the `query` package is importable
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from query.llm_interface import get_llm_interface
from query.query_processor import get_query_processor
async def test_classify_query_domain():
    """Test the classify_query_domain function.

    Classifies a fixed set of sample queries with the LLM interface, prints
    the primary/secondary classifications and reasoning for each, and saves
    the raw results to ``domain_classification_results.json``.
    """
    llm_interface = get_llm_interface()

    test_queries = [
        # Academic queries
        "What are the technological, economic, and social implications of large language models in today's society?",
        "What is the current state of research on quantum computing algorithms?",
        "How has climate change affected biodiversity in marine ecosystems?",

        # Code queries
        "How do I implement a transformer model in PyTorch for text classification?",
        "What's the best way to optimize a recursive function in Python?",
        "Explain how to use React hooks with TypeScript",

        # Current events queries
        "What are the latest developments in the Ukraine conflict?",
        "How has the Federal Reserve's recent interest rate decision affected the stock market?",
        "What were the outcomes of the recent climate summit?",

        # Mixed or general queries
        "How are LLMs being used to detect and prevent cyber attacks?",
        "What are the best practices for remote work?",
        "Compare electric vehicles to traditional gas-powered cars"
    ]

    results = []

    for query in test_queries:
        print(f"\nClassifying query: {query}")
        domain_classification = await llm_interface.classify_query_domain(query)

        print(f"Primary type: {domain_classification.get('primary_type')} (confidence: {domain_classification.get('confidence')})")

        # secondary_types may be absent or empty; each entry is expected to
        # carry 'type' and 'confidence' keys — confirm against the LLM interface.
        if domain_classification.get('secondary_types'):
            for sec_type in domain_classification.get('secondary_types'):
                print(f"Secondary type: {sec_type['type']} (confidence: {sec_type['confidence']})")

        print(f"Reasoning: {domain_classification.get('reasoning', 'None provided')}")

        results.append({
            'query': query,
            'classification': domain_classification
        })

    # Save results to a file for later inspection.
    with open('domain_classification_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Plain string: no placeholders, so no f-string needed (ruff F541).
    print("\nResults saved to domain_classification_results.json")
async def test_query_processor_with_domain_classification():
    """Test the query processor with the new domain classification.

    Runs a handful of queries through the full query processor, prints the
    domain-related fields it attaches to each structured query, and saves a
    summary to ``query_processor_domain_results.json``.
    """
    query_processor = get_query_processor()

    test_queries = [
        "What are the technological implications of large language models?",
        "How do I implement a transformer model in PyTorch?",
        "What are the latest developments in the Ukraine conflict?",
        "How are LLMs being used to detect cyber attacks?"
    ]

    results = []

    for query in test_queries:
        print(f"\nProcessing query: {query}")
        structured_query = await query_processor.process_query(query)

        print(f"Domain: {structured_query.get('domain')} (confidence: {structured_query.get('domain_confidence')})")
        print(f"Is academic: {structured_query.get('is_academic')}")
        print(f"Is code: {structured_query.get('is_code')}")
        print(f"Is current events: {structured_query.get('is_current_events')}")

        # secondary_domains may be None or empty; entries are expected to
        # carry 'type' and 'confidence' keys — confirm against the processor.
        if structured_query.get('secondary_domains'):
            for domain in structured_query.get('secondary_domains'):
                print(f"Secondary domain: {domain['type']} (confidence: {domain['confidence']})")

        print(f"Reasoning: {structured_query.get('classification_reasoning', 'None provided')}")

        # Keep only the classification-related fields in the saved summary.
        results.append({
            'query': query,
            'structured_query': {
                'domain': structured_query.get('domain'),
                'domain_confidence': structured_query.get('domain_confidence'),
                'is_academic': structured_query.get('is_academic'),
                'is_code': structured_query.get('is_code'),
                'is_current_events': structured_query.get('is_current_events'),
                'secondary_domains': structured_query.get('secondary_domains'),
                'classification_reasoning': structured_query.get('classification_reasoning')
            }
        })

    # Save results to a file for later inspection.
    with open('query_processor_domain_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Plain string: no placeholders, so no f-string needed (ruff F541).
    print("\nResults saved to query_processor_domain_results.json")
async def compare_with_keyword_classification():
    """Compare LLM-based classification with keyword-based classification.

    For each test query, classifies once via the normal LLM-backed path and
    once via the legacy keyword-based path (forced by temporarily monkey
    patching the query processor), prints both results side by side, and
    saves the comparison to ``classification_comparison_results.json``.
    """
    query_processor = get_query_processor()

    # Remember the real LLM-backed method so it can be restored after each
    # keyword-based run.
    original_structure_query_with_llm = query_processor._structure_query_with_llm

    # Test queries that might be challenging for a keyword-based approach.
    test_queries = [
        "How do language models work internally?",  # Could be academic or code
        "What are the best machine learning models for text generation?",  # "models" could trigger code
        "How has ChatGPT changed the AI landscape?",  # Recent but academic topic
        "What techniques help in understanding neural networks?",  # Could be academic or code
        "How are transformers used in NLP applications?",  # Ambiguous - could mean electrical transformers or ML
    ]

    results = []

    for query in test_queries:
        print(f"\nProcessing query with both methods: {query}")

        # First, use LLM-based classification (normal operation).
        structured_query_llm = await query_processor.process_query(query)

        # Now force keyword-based classification by monkey patching.
        # try/finally guarantees the original method is restored even if
        # process_query raises, so the processor is never left corrupted
        # for later iterations or other callers.
        query_processor._structure_query_with_llm = query_processor._structure_query
        try:
            structured_query_keyword = await query_processor.process_query(query)
        finally:
            query_processor._structure_query_with_llm = original_structure_query_with_llm

        # Compare results.
        print("LLM Classification:")
        print(f"  Domain: {structured_query_llm.get('domain')}")
        print(f"  Is academic: {structured_query_llm.get('is_academic')}")
        print(f"  Is code: {structured_query_llm.get('is_code')}")
        print(f"  Is current events: {structured_query_llm.get('is_current_events')}")

        print("Keyword Classification:")
        print(f"  Is academic: {structured_query_keyword.get('is_academic')}")
        print(f"  Is code: {structured_query_keyword.get('is_code')}")
        print(f"  Is current events: {structured_query_keyword.get('is_current_events')}")

        results.append({
            'query': query,
            'llm_classification': {
                'domain': structured_query_llm.get('domain'),
                'is_academic': structured_query_llm.get('is_academic'),
                'is_code': structured_query_llm.get('is_code'),
                'is_current_events': structured_query_llm.get('is_current_events')
            },
            'keyword_classification': {
                'is_academic': structured_query_keyword.get('is_academic'),
                'is_code': structured_query_keyword.get('is_code'),
                'is_current_events': structured_query_keyword.get('is_current_events')
            }
        })

    # Save comparison results to a file.
    with open('classification_comparison_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Plain string: no placeholders, so no f-string needed (ruff F541).
    print("\nComparison results saved to classification_comparison_results.json")
async def main(test_type: int = 1):
    """Run tests for query domain classification.

    Args:
        test_type: Which test to run —
            1: classify_query_domain function test (default),
            2: query processor with domain classification,
            3: LLM vs. keyword classification comparison.
            Any other value runs all three tests in sequence.
    """
    if test_type == 1:
        print("=== Testing classify_query_domain function ===")
        await test_classify_query_domain()
    elif test_type == 2:
        print("=== Testing query processor with domain classification ===")
        await test_query_processor_with_domain_classification()
    elif test_type == 3:
        print("=== Comparing LLM and keyword classifications ===")
        await compare_with_keyword_classification()
    else:
        print("=== Running all tests ===")
        await test_classify_query_domain()
        await test_query_processor_with_domain_classification()
        await compare_with_keyword_classification()
if __name__ == "__main__":
    # Entry point: run the selected test(s) on a fresh asyncio event loop.
    asyncio.run(main())