diff --git a/report/report_synthesis_test.py b/report/report_synthesis_test.py new file mode 100644 index 0000000..d47a323 --- /dev/null +++ b/report/report_synthesis_test.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Test script for the report synthesis module, specifically to verify +model provider selection works correctly. +""" + +import os +import sys +import asyncio +import logging +from typing import Dict, Any, List + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Add parent directory to path to import modules +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from config.config import Config +from report.report_synthesis import ReportSynthesizer + +async def test_model_provider_selection(): + """Test that model provider selection works correctly.""" + # Initialize config + config = Config() + + # Test with different models and providers + models_to_test = [ + {"provider": "groq", "model_name": "llama-3.3-70b-versatile"}, + {"provider": "gemini", "model_name": "gemini-2.0-flash"}, + {"provider": "anthropic", "model_name": "claude-3-opus-20240229"}, + {"provider": "openai", "model_name": "gpt-4-turbo"}, + ] + + for model_config in models_to_test: + provider = model_config["provider"] + model_name = model_config["model_name"] + + logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====") + + # Create a synthesizer with the specified model + # First update the config to use the specified provider + config.config_data['models'] = config.config_data.get('models', {}) + config.config_data['models'][model_name] = { + "provider": provider, + "model_name": model_name, + "temperature": 0.5, + "max_tokens": 2048, + "top_p": 1.0 + } + + # Create the synthesizer with the model name + synthesizer = ReportSynthesizer(model_name=model_name) + + # Verify the model and provider are set correctly + logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}") + logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}") + + # Get completion parameters to verify they're set correctly + params = synthesizer._get_completion_params() + logger.info(f"Completion parameters: {params}") + + # Create a simple test message + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Say hello and identify yourself as the model you're running on."} + ] + + try: + # Test the generate_completion method + logger.info("Testing generate_completion method...") + response = await synthesizer.generate_completion(messages) + logger.info(f"Response: {response[:100]}...") # Show first 100 chars + except Exception as e: + logger.error(f"Error generating completion: {e}") + # Continue with the next model even if this one fails + continue + + logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n") + +async def main(): + """Main function to run tests.""" + logger.info("Starting report synthesis tests...") + await test_model_provider_selection() + logger.info("All tests completed.") + +if __name__ == "__main__": + asyncio.run(main())