ira/test_query_processor_compre...

237 lines
7.0 KiB
Python

#!/usr/bin/env python3
"""
Comprehensive test script for the query processor module.
This script tests all the key functionality of the query processor with the Groq models.
"""
import os
import json
import time
from datetime import datetime
from typing import Dict, Any, List
from query.query_processor import QueryProcessor, get_query_processor
from query.llm_interface import LLMInterface, get_llm_interface
from config.config import get_config
# Create a config.yaml file if it doesn't exist
config_dir = os.path.join(os.path.dirname(__file__), "config")
config_file = os.path.join(config_dir, "config.yaml")
if not os.path.exists(config_file):
example_file = os.path.join(config_dir, "config.yaml.example")
if os.path.exists(example_file):
with open(example_file, "r") as f:
example_content = f.read()
with open(config_file, "w") as f:
f.write(example_content)
print(f"Created config.yaml from example file")
# Create one shared LLM interface bound to the Groq-hosted model. The test
# helpers in this script use it directly, and the patch below is meant to
# route the query processor to it as well.
groq_interface = get_llm_interface("llama-3.1-8b-instant")
print(f"Using model: {groq_interface.model_name}")

# Monkey patch the get_llm_interface function to always return our Groq interface
import query.llm_interface
import query.query_processor

# Keep a handle on the real factory so it can be restored if ever needed.
original_get_llm_interface = query.llm_interface.get_llm_interface


def patched_get_llm_interface(*args, **kwargs):
    """Ignore all arguments and return the shared Groq interface."""
    return groq_interface


# Rebinding the attribute on ``query.llm_interface`` only affects callers that
# look it up there at call time. A module that did
# ``from query.llm_interface import get_llm_interface`` holds its own reference
# to the original, so patch that name too when query_processor has one.
# NOTE(review): if get_query_processor() caches a processor that was created
# before this point, neither rebinding affects it. Confirm against query_processor.
query.llm_interface.get_llm_interface = patched_get_llm_interface
if hasattr(query.query_processor, "get_llm_interface"):
    query.query_processor.get_llm_interface = patched_get_llm_interface
# Test data: sample queries grouped by category. Each category contributes
# two queries, and the flattened list keeps the category order below.
_QUERY_CATEGORIES = {
    "simple factual": [
        "What is quantum computing?",
        "Who invented the internet?",
    ],
    "complex research": [
        "What are the latest advancements in renewable energy?",
        "How does artificial intelligence impact healthcare?",
    ],
    "comparative": [
        "Compare machine learning and deep learning",
        "What are the differences between solar and wind energy?",
    ],
    "domain-specific": [
        "Explain the CRISPR-Cas9 gene editing technology",
        "What are the implications of blockchain for finance?",
    ],
}

TEST_QUERIES = [text for group in _QUERY_CATEGORIES.values() for text in group]

# Engines for which engine-specific search queries are generated.
SEARCH_ENGINES = ["google", "bing", "scholar"]
def test_enhance_query(query: str, llm: Any = None) -> str:
    """
    Test the query enhancement functionality.

    Args:
        query: The query to enhance
        llm: Object exposing ``enhance_query(query)``. Defaults to the
            module-level ``groq_interface`` when omitted.

    Returns:
        The enhanced query
    """
    interface = groq_interface if llm is None else llm

    print("\nTesting Query Enhancement")
    print(f"Original Query: '{query}'")
    print("-" * 50)

    # perf_counter is monotonic and high-resolution. time.time() can jump
    # when the system clock is adjusted and distort the measured duration.
    start_time = time.perf_counter()
    enhanced_query = interface.enhance_query(query)
    elapsed = time.perf_counter() - start_time

    print(f"Processing time: {elapsed:.2f} seconds")
    print(f"Enhanced Query: '{enhanced_query}'")
    print("-" * 50)
    return enhanced_query


# This is a helper driven by run_comprehensive_tests(), not a pytest test.
# Its name matches pytest's default ``test_*`` pattern, and pytest would treat
# its ``query`` parameter as a missing fixture, so opt out of collection.
test_enhance_query.__test__ = False  # type: ignore[attr-defined]
def test_classify_query(query: str, llm: Any = None) -> Dict[str, Any]:
    """
    Test the query classification functionality.

    Args:
        query: The query to classify
        llm: Object exposing ``classify_query(query)``. Defaults to the
            module-level ``groq_interface`` when omitted.

    Returns:
        The classification result
    """
    interface = groq_interface if llm is None else llm

    print("\nTesting Query Classification")
    print(f"Query: '{query}'")
    print("-" * 50)

    # Monotonic clock, so system clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    classification = interface.classify_query(query)
    elapsed = time.perf_counter() - start_time

    print(f"Processing time: {elapsed:.2f} seconds")
    print(f"Classification: {json.dumps(classification, indent=2)}")
    print("-" * 50)
    return classification


# Helper for run_comprehensive_tests(), not a pytest test. Without this flag,
# pytest would collect it (``test_*`` name) and fail on the ``query`` "fixture".
test_classify_query.__test__ = False  # type: ignore[attr-defined]
def test_process_query(query: str, processor: Any = None) -> Dict[str, Any]:
    """
    Test the query processing functionality.

    Args:
        query: The query to process
        processor: Object exposing ``process_query(query)``. Defaults to the
            result of ``get_query_processor()`` when omitted.

    Returns:
        The processed query result, with a ``timestamp`` key added
    """
    # Default to the shared query processor. It is intended to pick up the
    # patched (Groq) LLM interface installed when this script is imported.
    if processor is None:
        processor = get_query_processor()

    print("\nTesting Query Processing")
    print(f"Query: '{query}'")
    print("-" * 50)

    # Monotonic clock, so system clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    result = processor.process_query(query)
    elapsed = time.perf_counter() - start_time

    # Record when this result was produced (naive local time, ISO 8601).
    result['timestamp'] = datetime.now().isoformat()

    print(f"Processing time: {elapsed:.2f} seconds")
    print(f"Original Query: {result['original_query']}")
    print(f"Enhanced Query: {result['enhanced_query']}")
    print(f"Query Type: {result['type']}")
    print(f"Query Intent: {result['intent']}")
    # str() each entity so non-string entries cannot make join() raise TypeError.
    print(f"Entities: {', '.join(map(str, result['entities']))}")
    print("-" * 50)
    return result


# Helper for run_comprehensive_tests(), not a pytest test. Without this flag,
# pytest would collect it (``test_*`` name) and fail on the ``query`` "fixture".
test_process_query.__test__ = False  # type: ignore[attr-defined]
def test_generate_search_queries(structured_query: Dict[str, Any],
                                 search_engines: List[str],
                                 processor: Any = None) -> Dict[str, Any]:
    """
    Test the search query generation functionality.

    Args:
        structured_query: The structured query to generate search queries for
        search_engines: List of search engines to generate queries for
        processor: Object exposing
            ``generate_search_queries(structured_query, search_engines)``.
            Defaults to the result of ``get_query_processor()`` when omitted.

    Returns:
        The updated structured query with search queries
    """
    # Default to the shared query processor. It is intended to pick up the
    # patched (Groq) LLM interface installed when this script is imported.
    if processor is None:
        processor = get_query_processor()

    print("\nTesting Search Query Generation")
    print(f"Engines: {', '.join(search_engines)}")
    print("-" * 50)

    # Monotonic clock, so system clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    result = processor.generate_search_queries(structured_query, search_engines)
    elapsed = time.perf_counter() - start_time

    print(f"Processing time: {elapsed:.2f} seconds")

    # Print the generated search queries, numbered per engine.
    for engine, queries in result['search_queries'].items():
        print(f"\n{engine.upper()} Queries:")
        for i, search_query in enumerate(queries, 1):
            print(f" {i}. {search_query}")

    print("-" * 50)
    return result


# Helper for run_comprehensive_tests(), not a pytest test. Without this flag,
# pytest would collect it (``test_*`` name) and fail on its parameters as "fixtures".
test_generate_search_queries.__test__ = False  # type: ignore[attr-defined]
def _save_results(results: List[Dict[str, Any]], output_file: str) -> None:
    """Write *results* to *output_file* as indented JSON and report the path."""
    # Serialize before opening the file. An unserializable value then fails
    # without leaving a truncated JSON file behind.
    payload = json.dumps(results, indent=2)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(payload)
    print(f"\nTest results saved to {output_file}")


def run_comprehensive_tests():
    """Run comprehensive tests on the query processor.

    For each entry in ``TEST_QUERIES`` this exercises query enhancement,
    classification and the full processing pipeline. Every second query
    (2nd, 4th, ...) also runs search-query generation against
    ``SEARCH_ENGINES``. Collected results are written to
    ``query_processor_test_results.json``.

    The results file is written even if a step raises part-way through the
    run (for example an API or rate-limit error), so completed results are
    not lost. The original exception still propagates afterwards.
    """
    results: List[Dict[str, Any]] = []
    output_file = "query_processor_test_results.json"

    try:
        for i, query in enumerate(TEST_QUERIES, 1):
            print(f"\n\nTEST {i}: {query}")
            print("=" * 80)

            # Test individual components
            enhanced_query = test_enhance_query(query)
            classification = test_classify_query(query)

            # Test the full query processing pipeline
            structured_query = test_process_query(query)

            # Search query generation only for every other query, to save time
            if i % 2 == 0:
                structured_query = test_generate_search_queries(
                    structured_query, SEARCH_ENGINES
                )

            results.append({
                "query": query,
                "enhanced_query": enhanced_query,
                "classification": classification,
                "structured_query": structured_query,
            })

            print("\n" + "=" * 80 + "\n")

            # Add a delay between tests to avoid rate limiting
            if i < len(TEST_QUERIES):
                print("Waiting 2 seconds before next test...")
                time.sleep(2)
    finally:
        # Runs on success and on failure alike, so partial progress is kept.
        _save_results(results, output_file)
if __name__ == "__main__":
    # Executed directly (not imported): run the full suite and write the report.
    run_comprehensive_tests()