ira/tests/report/test_custom_model.py

#!/usr/bin/env python
"""
Test Query to Report Script with Custom Model

This script tests the query_to_report.py script with a custom model and query.
"""

import os
import sys
import asyncio
import argparse
from datetime import datetime

# Add parent directory to path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from scripts.query_to_report import query_to_report
from report.report_detail_levels import get_report_detail_level_manager
from report.report_synthesis import ReportSynthesizer, get_report_synthesizer
from config.config import get_config


async def run_custom_model_test(
    query: str,
    model_name: str,
    detail_level: str = "standard",
    use_mock: bool = False,
    process_thinking_tags: bool = False
):
    """
    Run a test of the query to report workflow with a custom model.

    Args:
        query: The query to process
        model_name: The name of the model to use
        detail_level: Level of detail for the report (brief, standard, detailed, comprehensive)
        use_mock: If True, use mock data instead of making actual API calls
        process_thinking_tags: If True, process and remove <thinking> tags from the model output
    """
    # Generate timestamp for unique output file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_short_name = model_name.split('/')[-1] if '/' in model_name else model_name
    output_file = f"report_{timestamp}_{model_short_name}.md"

    print(f"Processing query: {query}")
    print(f"Model: {model_name}")
    print(f"Detail level: {detail_level}")
    print(f"Process thinking tags: {process_thinking_tags}")
    print("This may take a few minutes depending on the number of search results and API response times...")

    # Get detail level configuration
    detail_level_manager = get_report_detail_level_manager()
    config = detail_level_manager.get_detail_level_config(detail_level)

    # Print detail level configuration
    print("\nDetail level configuration:")
    print(f"  Number of results per search engine: {config.get('num_results')}")
    print(f"  Token budget: {config.get('token_budget')}")
    print(f"  Chunk size: {config.get('chunk_size')}")
    print(f"  Overlap size: {config.get('overlap_size')}")
    print(f"  Default model: {config.get('model')}")
    print(f"  Using custom model: {model_name}")
    # Create a custom report synthesizer with the specified model
    custom_synthesizer = ReportSynthesizer(model_name=model_name)

    # Set the process_thinking_tags flag if needed
    if process_thinking_tags:
        custom_synthesizer.process_thinking_tags = True

    # Store the original synthesizer to restore later
    original_synthesizer = get_report_synthesizer()

    # Replace the module-level synthesizer with our custom one
    report_synthesis_module = sys.modules['report.report_synthesis']
    report_synthesis_module.report_synthesizer = custom_synthesizer
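    # NOTE: this assumes query_to_report looks up report_synthesizer on the
    # report.report_synthesis module at call time; if it had imported the
    # object directly into its own namespace, swapping the module attribute
    # here would not redirect synthesis to the custom model.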

    try:
        # Run the workflow
        await query_to_report(
            query=query,
            output_file=output_file,
            detail_level=detail_level,
            use_mock=use_mock
        )

        print("\nTest completed successfully!")
        print(f"Report saved to: {output_file}")

        # Print the first few lines of the report
        try:
            with open(output_file, 'r', encoding='utf-8') as f:
                preview = f.read(1000)  # Show a larger preview
            print("\nReport Preview:")
            print("-" * 80)
            print(preview + "...")
            print("-" * 80)
        except Exception as e:
            print(f"Error reading report: {e}")
    finally:
        # Restore the original synthesizer
        report_synthesis_module.report_synthesizer = original_synthesizer


def main():
    """Main function to parse arguments and run the test."""
    parser = argparse.ArgumentParser(description='Test the query to report workflow with a custom model')
    parser.add_argument('query', help='The query to process')
    parser.add_argument('--model', '-m', required=True,
                        help='The model to use (e.g., groq/deepseek-r1-distill-llama-70b-specdec)')
    parser.add_argument('--detail-level', '-d', type=str, default='standard',
                        choices=['brief', 'standard', 'detailed', 'comprehensive'],
                        help='Level of detail for the report')
    parser.add_argument('--use-mock', action='store_true', help='Use mock data instead of API calls')
    parser.add_argument('--process-thinking-tags', '-t', action='store_true',
                        help='Process and remove <thinking> tags from model output')
    args = parser.parse_args()

    # Run the test
    asyncio.run(run_custom_model_test(
        query=args.query,
        model_name=args.model,
        detail_level=args.detail_level,
        use_mock=args.use_mock,
        process_thinking_tags=args.process_thinking_tags
    ))


if __name__ == "__main__":
    main()