Add comprehensive test script to verify model provider selection fix
This commit is contained in:
parent
15357890ea
commit
4d622de48d
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Test script for the report synthesis module, specifically to verify
|
||||
model provider selection works correctly.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add parent directory to path to import modules
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config.config import Config
|
||||
from report.report_synthesis import ReportSynthesizer
|
||||
|
||||
async def test_model_provider_selection():
|
||||
"""Test that model provider selection works correctly."""
|
||||
# Initialize config
|
||||
config = Config()
|
||||
|
||||
# Test with different models and providers
|
||||
models_to_test = [
|
||||
{"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
|
||||
{"provider": "gemini", "model_name": "gemini-2.0-flash"},
|
||||
{"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
|
||||
{"provider": "openai", "model_name": "gpt-4-turbo"},
|
||||
]
|
||||
|
||||
for model_config in models_to_test:
|
||||
provider = model_config["provider"]
|
||||
model_name = model_config["model_name"]
|
||||
|
||||
logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")
|
||||
|
||||
# Create a synthesizer with the specified model
|
||||
# First update the config to use the specified provider
|
||||
config.config_data['models'] = config.config_data.get('models', {})
|
||||
config.config_data['models'][model_name] = {
|
||||
"provider": provider,
|
||||
"model_name": model_name,
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"top_p": 1.0
|
||||
}
|
||||
|
||||
# Create the synthesizer with the model name
|
||||
synthesizer = ReportSynthesizer(model_name=model_name)
|
||||
|
||||
# Verify the model and provider are set correctly
|
||||
logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
|
||||
logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")
|
||||
|
||||
# Get completion parameters to verify they're set correctly
|
||||
params = synthesizer._get_completion_params()
|
||||
logger.info(f"Completion parameters: {params}")
|
||||
|
||||
# Create a simple test message
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Say hello and identify yourself as the model you're running on."}
|
||||
]
|
||||
|
||||
try:
|
||||
# Test the generate_completion method
|
||||
logger.info("Testing generate_completion method...")
|
||||
response = await synthesizer.generate_completion(messages)
|
||||
logger.info(f"Response: {response[:100]}...") # Show first 100 chars
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating completion: {e}")
|
||||
# Continue with the next model even if this one fails
|
||||
continue
|
||||
|
||||
logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")
|
||||
|
||||
async def main():
|
||||
"""Main function to run tests."""
|
||||
logger.info("Starting report synthesis tests...")
|
||||
await test_model_provider_selection()
|
||||
logger.info("All tests completed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Loading…
Reference in New Issue