#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Test script for the report synthesis module, specifically to verify that
model provider selection works correctly.
"""

import os
import sys
import asyncio
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add the parent directory to the path so project modules can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
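
# The project imports below rely on the sys.path tweak above, so they are
# deliberately kept after it rather than at the top of the file.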
from config.config import Config
from report.report_synthesis import ReportSynthesizer


async def test_model_provider_selection():
    """Test that model provider selection works correctly."""
    # Initialize config
    config = Config()

    # Test with different models and providers
    models_to_test = [
        {"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
        {"provider": "gemini", "model_name": "gemini-2.0-flash"},
        {"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
        {"provider": "openai", "model_name": "gpt-4-turbo"},
    ]
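
    # NOTE: each provider above is assumed to have its API key configured in
    # the environment; a model whose call fails is logged and skipped (see the
    # try/except in the loop below).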
    for model_config in models_to_test:
        provider = model_config["provider"]
        model_name = model_config["model_name"]

        logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")
        # Create a synthesizer with the specified model.
        # First update the config to use the specified provider.
        config.config_data['models'] = config.config_data.get('models', {})
        config.config_data['models'][model_name] = {
            "provider": provider,
            "model_name": model_name,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }
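        # ReportSynthesizer is expected to read these per-model settings back
        # from the shared Config when it is constructed below.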
        # Create the synthesizer with the model name
        synthesizer = ReportSynthesizer(model_name=model_name)

        # Verify the model and provider are set correctly
        logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
        logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")
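        # _get_completion_params is a private helper; it is called here only so
        # the test can inspect the parameters the synthesizer will use.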
        # Get completion parameters to verify they're set correctly
        params = synthesizer._get_completion_params()
        logger.info(f"Completion parameters: {params}")
        # Create a simple test message
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello and identify yourself as the model you're running on."},
        ]
        try:
            # Test the generate_completion method
            logger.info("Testing generate_completion method...")
            response = await synthesizer.generate_completion(messages)
            logger.info(f"Response: {response[:100]}...")  # Show first 100 chars
        except Exception as e:
            logger.error(f"Error generating completion: {e}")
            # Continue with the next model even if this one fails
            continue

        logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")


async def main():
    """Main function to run tests."""
    logger.info("Starting report synthesis tests...")
    await test_model_provider_selection()
    logger.info("All tests completed.")


if __name__ == "__main__":
    asyncio.run(main())
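
# To run the tests directly (the script name here is illustrative):
#   python test_report_synthesis.py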