ira/tests/query/test_query_processor.py

149 lines
4.6 KiB
Python

#!/usr/bin/env python3
"""
Test script for the query processor module.
This script tests the query processor with the Groq models.
"""
import json
import os
import time
from datetime import datetime
from typing import Dict, Any

from query.query_processor import QueryProcessor, get_query_processor
from query.llm_interface import LLMInterface, get_llm_interface
from config.config import get_config
# Create a config.yaml file if it doesn't exist.
# The file is looked up in a "config" directory located next to this script;
# when it is missing and a config.yaml.example is present there, the example
# is copied to config.yaml.
config_dir = os.path.join(os.path.dirname(__file__), "config")
config_file = os.path.join(config_dir, "config.yaml")
if not os.path.exists(config_file):
    example_file = os.path.join(config_dir, "config.yaml.example")
    if os.path.exists(example_file):
        # Copy raw bytes: a text-mode copy would decode/encode with the
        # platform's default encoding and translate newlines, which can fail
        # on or alter an example file containing non-ASCII characters.
        with open(example_file, "rb") as f:
            example_content = f.read()
        with open(config_file, "wb") as f:
            f.write(example_content)
        print("Created config.yaml from example file")
# Force the use of the Groq model for testing.
# First, create one shared LLM interface backed by the Groq model.
groq_interface = get_llm_interface("llama-3.1-8b-instant")
print(f"Using model: {groq_interface.model_name}")

# Monkey patch the get_llm_interface factory so that later lookups return the
# shared Groq interface instead of building a new one.
import query.llm_interface
import query.query_processor

# Keep a handle on the real factory so it can be restored or compared against.
original_get_llm_interface = query.llm_interface.get_llm_interface


def patched_get_llm_interface(*args, **kwargs):
    """Return the shared Groq interface, ignoring all arguments."""
    return groq_interface


query.llm_interface.get_llm_interface = patched_get_llm_interface

# query.query_processor was imported before this patch ran. If it bound the
# factory by name (``from query.llm_interface import get_llm_interface``) it
# still references the original function, and patching query.llm_interface
# alone would not affect it. Rebind that reference too, but only when it is
# exactly the original factory, so nothing unrelated is overwritten.
# NOTE(review): query.query_processor is not visible here - confirm how it
# obtains its LLM interface.
if getattr(query.query_processor, "get_llm_interface", None) is original_get_llm_interface:
    query.query_processor.get_llm_interface = patched_get_llm_interface
def test_process_query(query: str) -> Dict[str, Any]:
    """
    Run one query through the query processor and print the structured result.

    Args:
        query: The query to process

    Returns:
        The processed query result, with an added 'timestamp' key holding the
        ISO-formatted local time taken right after processing finished

    Raises:
        KeyError: If the result lacks one of the keys printed below
            ('original_query', 'enhanced_query', 'type', 'intent', 'entities')
    """
    # Get the query processor (intended to use the patched LLM interface)
    processor = get_query_processor()

    # Process the query
    print(f"\nProcessing query: '{query}'")
    print("-" * 50)

    # Measure elapsed time with a monotonic clock; wall-clock differences can
    # be skewed (even negative) if the system time is adjusted mid-call.
    started = time.perf_counter()
    result = processor.process_query(query)
    processing_time = time.perf_counter() - started

    # Add timestamp
    result['timestamp'] = datetime.now().isoformat()

    print(f"Processing time: {processing_time:.2f} seconds")

    # Print the result in a formatted way
    print("\nProcessed Query Result:")
    print("-" * 50)
    print(f"Original Query: {result['original_query']}")
    print(f"Enhanced Query: {result['enhanced_query']}")
    print(f"Query Type: {result['type']}")
    print(f"Query Intent: {result['intent']}")
    # Coerce each entity to str so the join cannot fail on a non-string item.
    print(f"Entities: {', '.join(str(entity) for entity in result['entities'])}")
    print("-" * 50)
    return result
def test_generate_search_queries(structured_query: Dict[str, Any],
                                 search_engines: list = None) -> Dict[str, Any]:
    """
    Generate search queries for a structured query and print them per engine.

    Args:
        structured_query: The structured query to generate search queries for
        search_engines: List of search engines to generate queries for;
            defaults to ["google", "bing", "scholar"] when None

    Returns:
        The value returned by the processor's generate_search_queries call;
        its 'search_queries' entry is read as a mapping of engine name to an
        iterable of queries
    """
    # Work on a private list so the caller's sequence is never aliased.
    if search_engines is None:
        search_engines = ["google", "bing", "scholar"]
    else:
        search_engines = list(search_engines)

    # Get the query processor (intended to use the patched LLM interface)
    processor = get_query_processor()

    # Generate search queries
    print(f"\nGenerating search queries for engines: {', '.join(search_engines)}")
    print("-" * 50)

    # Measure elapsed time with a monotonic clock; wall-clock differences can
    # be skewed (even negative) if the system time is adjusted mid-call.
    started = time.perf_counter()
    result = processor.generate_search_queries(structured_query, search_engines)
    processing_time = time.perf_counter() - started
    print(f"Processing time: {processing_time:.2f} seconds")

    # Print the generated search queries
    print("\nGenerated Search Queries:")
    print("-" * 50)
    for engine, queries in result['search_queries'].items():
        print(f"\n{engine.upper()} Queries:")
        for i, query in enumerate(queries, 1):
            print(f" {i}. {query}")
    print("-" * 50)
    return result
def main():
    """Process each sample query, then generate and print its search queries."""
    # Sample queries exercised by this script, in order.
    sample_queries = (
        "What are the latest advancements in quantum computing?",
        "Compare renewable energy sources and their efficiency",
        "Explain the impact of artificial intelligence on healthcare",
    )
    # Visual divider printed after each query's output.
    divider = "\n" + "=" * 80 + "\n"

    for sample in sample_queries:
        # Structure the query first, then feed the result into query generation.
        processed = test_process_query(sample)
        test_generate_search_queries(processed)
        print(divider)


if __name__ == "__main__":
    main()