#!/usr/bin/env python3
"""
Comprehensive test script for the query processor module.

This script tests all the key functionality of the query processor with the Groq models.
"""

import os
import json
import time
from datetime import datetime
from typing import Dict, Any, List

from query.query_processor import QueryProcessor, get_query_processor
from query.llm_interface import LLMInterface, get_llm_interface
from config.config import get_config

# Create a config.yaml file if it doesn't exist
config_dir = os.path.join(os.path.dirname(__file__), "config")
config_file = os.path.join(config_dir, "config.yaml")
if not os.path.exists(config_file):
    example_file = os.path.join(config_dir, "config.yaml.example")
    if os.path.exists(example_file):
        with open(example_file, "r") as f:
            example_content = f.read()

        with open(config_file, "w") as f:
            f.write(example_content)

        print("Created config.yaml from example file")

# Create a global LLM interface with the Groq model
groq_interface = get_llm_interface("llama-3.1-8b-instant")
print(f"Using model: {groq_interface.model_name}")

# Monkey patch the get_llm_interface function to always return our Groq interface
import query.llm_interface

original_get_llm_interface = query.llm_interface.get_llm_interface


def patched_get_llm_interface(*args, **kwargs):
    return groq_interface


query.llm_interface.get_llm_interface = patched_get_llm_interface
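
# Caveat on the patch above (a defensive sketch, not part of the original
# flow): rebinding query.llm_interface.get_llm_interface only affects code
# that looks the function up through the module at call time. Any module
# that ran "from query.llm_interface import get_llm_interface" at import
# time keeps its own reference to the original function and is unaffected.
# Assuming query.query_processor might import the name that way, rebind
# its copy too:
import query.query_processor

if hasattr(query.query_processor, "get_llm_interface"):
    query.query_processor.get_llm_interface = patched_get_llm_interface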

# Test data
TEST_QUERIES = [
    # Simple factual queries
    "What is quantum computing?",
    "Who invented the internet?",

    # Complex research queries
    "What are the latest advancements in renewable energy?",
    "How does artificial intelligence impact healthcare?",

    # Comparative queries
    "Compare machine learning and deep learning",
    "What are the differences between solar and wind energy?",

    # Domain-specific queries
    "Explain the CRISPR-Cas9 gene editing technology",
    "What are the implications of blockchain for finance?"
]

SEARCH_ENGINES = ["google", "bing", "scholar"]


def test_enhance_query(query: str) -> str:
    """
    Test the query enhancement functionality.

    Args:
        query: The query to enhance

    Returns:
        The enhanced query
    """
    print("\nTesting Query Enhancement")
    print(f"Original Query: '{query}'")
    print("-" * 50)

    start_time = time.time()
    enhanced_query = groq_interface.enhance_query(query)
    end_time = time.time()

    print(f"Processing time: {end_time - start_time:.2f} seconds")
    print(f"Enhanced Query: '{enhanced_query}'")
    print("-" * 50)

    return enhanced_query


def test_classify_query(query: str) -> Dict[str, Any]:
    """
    Test the query classification functionality.

    Args:
        query: The query to classify

    Returns:
        The classification result
    """
    print("\nTesting Query Classification")
    print(f"Query: '{query}'")
    print("-" * 50)

    start_time = time.time()
    classification = groq_interface.classify_query(query)
    end_time = time.time()

    print(f"Processing time: {end_time - start_time:.2f} seconds")
    print(f"Classification: {json.dumps(classification, indent=2)}")
    print("-" * 50)

    return classification


def test_process_query(query: str) -> Dict[str, Any]:
    """
    Test the query processing functionality.

    Args:
        query: The query to process

    Returns:
        The processed query result
    """
    # Get the query processor (which will use our patched LLM interface)
    processor = get_query_processor()

    # Process the query
    print("\nTesting Query Processing")
    print(f"Query: '{query}'")
    print("-" * 50)

    start_time = time.time()
    result = processor.process_query(query)
    end_time = time.time()

    # Add a timestamp to the result
    result['timestamp'] = datetime.now().isoformat()

    # Report the processing time
    print(f"Processing time: {end_time - start_time:.2f} seconds")

    # Print the result in a formatted way
    print(f"Original Query: {result['original_query']}")
    print(f"Enhanced Query: {result['enhanced_query']}")
    print(f"Query Type: {result['type']}")
    print(f"Query Intent: {result['intent']}")
    print(f"Entities: {', '.join(result['entities'])}")
    print("-" * 50)

    return result
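
# For reference, the structured query handed to test_generate_search_queries()
# below is the dict returned by process_query(). Inferred from the fields
# printed above (the authoritative schema lives in query.query_processor),
# it looks roughly like:
#
#     {
#         "original_query": "What is quantum computing?",
#         "enhanced_query": "...",
#         "type": "factual",            # illustrative value
#         "intent": "...",
#         "entities": ["quantum computing"],
#         "timestamp": "2025-01-01T00:00:00",
#     }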


def test_generate_search_queries(structured_query: Dict[str, Any],
                                 search_engines: List[str]) -> Dict[str, Any]:
    """
    Test the search query generation functionality.

    Args:
        structured_query: The structured query to generate search queries for
        search_engines: List of search engines to generate queries for

    Returns:
        The updated structured query with search queries
    """
    # Get the query processor (which will use our patched LLM interface)
    processor = get_query_processor()

    # Generate search queries
    print("\nTesting Search Query Generation")
    print(f"Engines: {', '.join(search_engines)}")
    print("-" * 50)

    start_time = time.time()
    result = processor.generate_search_queries(structured_query, search_engines)
    end_time = time.time()

    # Report the processing time
    print(f"Processing time: {end_time - start_time:.2f} seconds")

    # Print the generated search queries
    for engine, queries in result['search_queries'].items():
        print(f"\n{engine.upper()} Queries:")
        for i, query in enumerate(queries, 1):
            print(f"  {i}. {query}")
    print("-" * 50)

    return result


def run_comprehensive_tests():
    """Run comprehensive tests on the query processor."""
    results = []

    for i, query in enumerate(TEST_QUERIES, 1):
        print(f"\n\nTEST {i}: {query}")
        print("=" * 80)

        # Test individual components
        enhanced_query = test_enhance_query(query)
        classification = test_classify_query(query)

        # Test the full query processing pipeline
        structured_query = test_process_query(query)

        # Test search query generation for a subset of queries
        if i % 2 == 0:  # Only test every other query to save time
            search_result = test_generate_search_queries(structured_query, SEARCH_ENGINES)
            structured_query = search_result

        # Save results
        results.append({
            "query": query,
            "enhanced_query": enhanced_query,
            "classification": classification,
            "structured_query": structured_query
        })

        print("\n" + "=" * 80 + "\n")

        # Add a delay between tests to avoid rate limiting
        if i < len(TEST_QUERIES):
            print("Waiting 2 seconds before next test...")
            time.sleep(2)

    # Save results to a file
    output_file = "query_processor_test_results.json"
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nTest results saved to {output_file}")


if __name__ == "__main__":
    run_comprehensive_tests()