Add comprehensive test script to verify model provider selection fix

This commit is contained in:
Steve White 2025-03-19 08:05:16 -05:00
parent 15357890ea
commit 4d622de48d
1 changed files with 91 additions and 0 deletions

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Test script for the report synthesis module, specifically to verify
model provider selection works correctly.
"""
import os
import sys
import asyncio
import logging
from typing import Dict, Any, List
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Add parent directory to path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config.config import Config
from report.report_synthesis import ReportSynthesizer
async def test_model_provider_selection():
"""Test that model provider selection works correctly."""
# Initialize config
config = Config()
# Test with different models and providers
models_to_test = [
{"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
{"provider": "gemini", "model_name": "gemini-2.0-flash"},
{"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
{"provider": "openai", "model_name": "gpt-4-turbo"},
]
for model_config in models_to_test:
provider = model_config["provider"]
model_name = model_config["model_name"]
logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")
# Create a synthesizer with the specified model
# First update the config to use the specified provider
config.config_data['models'] = config.config_data.get('models', {})
config.config_data['models'][model_name] = {
"provider": provider,
"model_name": model_name,
"temperature": 0.5,
"max_tokens": 2048,
"top_p": 1.0
}
# Create the synthesizer with the model name
synthesizer = ReportSynthesizer(model_name=model_name)
# Verify the model and provider are set correctly
logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")
# Get completion parameters to verify they're set correctly
params = synthesizer._get_completion_params()
logger.info(f"Completion parameters: {params}")
# Create a simple test message
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Say hello and identify yourself as the model you're running on."}
]
try:
# Test the generate_completion method
logger.info("Testing generate_completion method...")
response = await synthesizer.generate_completion(messages)
logger.info(f"Response: {response[:100]}...") # Show first 100 chars
except Exception as e:
logger.error(f"Error generating completion: {e}")
# Continue with the next model even if this one fails
continue
logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")
async def main():
"""Main function to run tests."""
logger.info("Starting report synthesis tests...")
await test_model_provider_selection()
logger.info("All tests completed.")
if __name__ == "__main__":
asyncio.run(main())