ira/tests/query/test_domain_classification.py

210 lines
8.8 KiB
Python

"""
Test the query domain classification functionality.

This script exercises the new LLM-based query domain classification to check
that it correctly classifies queries into the academic, code, current_events,
and general categories.
"""
import os
import sys
import json
import asyncio
from typing import Dict, Any, List
# Add parent directory to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from query.llm_interface import get_llm_interface
from query.query_processor import get_query_processor
async def test_classify_query_domain():
    """Classify a fixed set of sample queries and dump the raw results.

    Each query is passed to ``classify_query_domain`` on the object returned
    by ``get_llm_interface()``. The primary type, any secondary types and the
    reasoning are printed, and every (query, classification) pair is written
    to ``domain_classification_results.json`` in the current working
    directory. Nothing is asserted; the output is meant for manual review.
    """
    classifier = get_llm_interface()

    # Sample queries grouped by the category they are intended to probe.
    queries_by_category = {
        'academic': (
            "What are the technological, economic, and social implications of large language models in today's society?",
            "What is the current state of research on quantum computing algorithms?",
            "How has climate change affected biodiversity in marine ecosystems?",
        ),
        'code': (
            "How do I implement a transformer model in PyTorch for text classification?",
            "What's the best way to optimize a recursive function in Python?",
            "Explain how to use React hooks with TypeScript",
        ),
        'current_events': (
            "What are the latest developments in the Ukraine conflict?",
            "How has the Federal Reserve's recent interest rate decision affected the stock market?",
            "What were the outcomes of the recent climate summit?",
        ),
        'mixed_or_general': (
            "How are LLMs being used to detect and prevent cyber attacks?",
            "What are the best practices for remote work?",
            "Compare electric vehicles to traditional gas-powered cars",
        ),
    }
    # Flatten in declaration order so queries run in the same sequence as listed.
    all_queries = [
        text
        for group in queries_by_category.values()
        for text in group
    ]

    collected = []
    for query in all_queries:
        print(f"\nClassifying query: {query}")
        verdict = await classifier.classify_query_domain(query)

        primary = verdict.get('primary_type')
        confidence = verdict.get('confidence')
        print(f"Primary type: {primary} (confidence: {confidence})")

        # A missing or empty 'secondary_types' entry simply prints nothing.
        for secondary in verdict.get('secondary_types') or ():
            secondary_name = secondary['type']
            secondary_confidence = secondary['confidence']
            print(f"Secondary type: {secondary_name} (confidence: {secondary_confidence})")

        reasoning = verdict.get('reasoning', 'None provided')
        print(f"Reasoning: {reasoning}")

        collected.append({'query': query, 'classification': verdict})

    # Persist everything once all queries have been classified.
    with open('domain_classification_results.json', 'w') as out_file:
        json.dump(collected, out_file, indent=2)
    print(f"\nResults saved to domain_classification_results.json")
async def test_query_processor_with_domain_classification():
    """Run sample queries through the query processor and record domain fields.

    Each query is passed to ``process_query`` on the object returned by
    ``get_query_processor()``. The domain-related fields of the structured
    query are printed and then written to
    ``query_processor_domain_results.json`` in the current working directory.
    Nothing is asserted; the output is meant for manual review.
    """
    processor = get_query_processor()

    sample_queries = (
        "What are the technological implications of large language models?",
        "How do I implement a transformer model in PyTorch?",
        "What are the latest developments in the Ukraine conflict?",
        "How are LLMs being used to detect cyber attacks?",
    )
    # Fields copied from the structured query into the saved report, in the
    # order they appear in the output file.
    recorded_fields = (
        'domain',
        'domain_confidence',
        'is_academic',
        'is_code',
        'is_current_events',
        'secondary_domains',
        'classification_reasoning',
    )

    report = []
    for query in sample_queries:
        print(f"\nProcessing query: {query}")
        structured = await processor.process_query(query)

        domain = structured.get('domain')
        domain_confidence = structured.get('domain_confidence')
        print(f"Domain: {domain} (confidence: {domain_confidence})")
        print(f"Is academic: {structured.get('is_academic')}")
        print(f"Is code: {structured.get('is_code')}")
        print(f"Is current events: {structured.get('is_current_events')}")

        # A missing or empty 'secondary_domains' entry simply prints nothing.
        for secondary in structured.get('secondary_domains') or ():
            secondary_name = secondary['type']
            secondary_confidence = secondary['confidence']
            print(f"Secondary domain: {secondary_name} (confidence: {secondary_confidence})")

        reasoning = structured.get('classification_reasoning', 'None provided')
        print(f"Reasoning: {reasoning}")

        summary = {field: structured.get(field) for field in recorded_fields}
        report.append({'query': query, 'structured_query': summary})

    # Persist everything once all queries have been processed.
    with open('query_processor_domain_results.json', 'w') as out_file:
        json.dump(report, out_file, indent=2)
    print(f"\nResults saved to query_processor_domain_results.json")
async def compare_with_keyword_classification():
    """Compare LLM-based classification with keyword-based classification.

    Every sample query is processed twice with the same query processor:
    once unmodified, and once with ``_structure_query_with_llm`` temporarily
    replaced by ``_structure_query`` (the keyword-based path, per the original
    author's notes). Both sets of flags are printed side by side and written
    to ``classification_comparison_results.json`` in the current working
    directory. Nothing is asserted; the output is meant for manual review.

    The temporary replacement is undone in a ``finally`` clause, so the
    processor object is restored even when the patched call raises.
    """
    query_processor = get_query_processor()
    # Keep a handle on the original method so the monkey patch can be undone.
    original_structure_query_with_llm = query_processor._structure_query_with_llm

    # Test queries that might be challenging for keyword-based approach
    test_queries = [
        "How do language models work internally?",  # Could be academic or code
        "What are the best machine learning models for text generation?",  # "models" could trigger code
        "How has ChatGPT changed the AI landscape?",  # Recent but academic topic
        "What techniques help in understanding neural networks?",  # Could be academic or code
        "How are transformers used in NLP applications?",  # Ambiguous - could mean electrical transformers or ML
    ]

    results = []
    for query in test_queries:
        print(f"\nProcessing query with both methods: {query}")

        # First, use LLM-based classification (normal operation)
        structured_query_llm = await query_processor.process_query(query)

        # Now, force keyword-based classification by monkey patching. The
        # restore lives in ``finally`` so a failing call cannot leave the
        # processor object patched for later queries or other callers.
        query_processor._structure_query_with_llm = query_processor._structure_query
        try:
            structured_query_keyword = await query_processor.process_query(query)
        finally:
            query_processor._structure_query_with_llm = original_structure_query_with_llm

        # Compare results
        print("LLM Classification:")
        print(f" Domain: {structured_query_llm.get('domain')}")
        print(f" Is academic: {structured_query_llm.get('is_academic')}")
        print(f" Is code: {structured_query_llm.get('is_code')}")
        print(f" Is current events: {structured_query_llm.get('is_current_events')}")
        print("Keyword Classification:")
        print(f" Is academic: {structured_query_keyword.get('is_academic')}")
        print(f" Is code: {structured_query_keyword.get('is_code')}")
        print(f" Is current events: {structured_query_keyword.get('is_current_events')}")

        results.append({
            'query': query,
            'llm_classification': {
                'domain': structured_query_llm.get('domain'),
                'is_academic': structured_query_llm.get('is_academic'),
                'is_code': structured_query_llm.get('is_code'),
                'is_current_events': structured_query_llm.get('is_current_events'),
            },
            'keyword_classification': {
                'is_academic': structured_query_keyword.get('is_academic'),
                'is_code': structured_query_keyword.get('is_code'),
                'is_current_events': structured_query_keyword.get('is_current_events'),
            },
        })

    # Save comparison results to a file
    with open('classification_comparison_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\nComparison results saved to classification_comparison_results.json")
async def main(test_type: int = 1) -> None:
    """Run tests for query domain classification.

    Args:
        test_type: Which test to run. ``1`` runs the classify_query_domain
            test, ``2`` the query-processor test, ``3`` the LLM-vs-keyword
            comparison; any other value runs all three in that order. The
            default of ``1`` matches the value that used to be hard-coded
            in this function.
    """
    if test_type == 1:
        print("=== Testing classify_query_domain function ===")
        await test_classify_query_domain()
    elif test_type == 2:
        print("=== Testing query processor with domain classification ===")
        await test_query_processor_with_domain_classification()
    elif test_type == 3:
        print("=== Comparing LLM and keyword classifications ===")
        await compare_with_keyword_classification()
    else:
        print("=== Running all tests ===")
        await test_classify_query_domain()
        await test_query_processor_with_domain_classification()
        await compare_with_keyword_classification()


if __name__ == "__main__":
    # An optional first command-line argument selects the test, so the source
    # no longer has to be edited to switch tests. With no argument the
    # behaviour is unchanged (test 1).
    selected_test = 1
    if len(sys.argv) > 1:
        try:
            selected_test = int(sys.argv[1])
        except ValueError:
            sys.exit(
                f"Invalid test type {sys.argv[1]!r}: expected an integer "
                "(1, 2, 3, or any other number to run all tests)"
            )
    asyncio.run(main(selected_test))