#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Test script for the report synthesis module, specifically to verify
model provider selection works correctly.
"""

import os
import sys
import asyncio
import logging
from typing import Dict, Any, List

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add parent directory to path to import modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.config import Config
from report.report_synthesis import ReportSynthesizer


async def test_model_provider_selection():
    """Test that model provider selection works correctly.

    Registers each (provider, model) pair in the config, builds a
    ReportSynthesizer for it, logs the resolved provider and completion
    parameters, and attempts a real completion. A failing completion is
    logged and skipped so the remaining models are still exercised.
    """
    logger.info("=== Testing basic model provider selection ===")

    # Initialize config
    config = Config()

    # Test with different models and providers
    models_to_test = [
        {"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
        {"provider": "gemini", "model_name": "gemini-2.0-flash"},
        {"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
        {"provider": "openai", "model_name": "gpt-4-turbo"},
    ]

    for model_config in models_to_test:
        provider = model_config["provider"]
        model_name = model_config["model_name"]

        logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")

        # Create a synthesizer with the specified model.
        # First update the config to use the specified provider.
        config.config_data['models'] = config.config_data.get('models', {})
        config.config_data['models'][model_name] = {
            "provider": provider,
            "model_name": model_name,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }

        # Create the synthesizer with the model name
        synthesizer = ReportSynthesizer(model_name=model_name)

        # Verify the model and provider are set correctly
        logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
        logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")

        # Get completion parameters to verify they're set correctly
        params = synthesizer._get_completion_params()
        logger.info(f"Completion parameters: {params}")

        # Create a simple test message
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello and identify yourself as the model you're running on."}
        ]

        try:
            # Test the generate_completion method
            logger.info("Testing generate_completion method...")
            response = await synthesizer.generate_completion(messages)
            logger.info(f"Response: {response[:100]}...")  # Show first 100 chars
        except Exception as e:
            logger.error(f"Error generating completion: {e}")
            # Continue with the next model even if this one fails
            # (skips the per-model completion log below)
            continue

        logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")


async def test_provider_selection_stability():
    """Test that provider selection remains stable across various scenarios.

    Covers six scenarios: repeated initialization with one model, switching
    between models, reacting to direct config edits, singleton vs. new
    instances, preservation of an invalid provider, and the fallback when no
    provider is configured. Mutations to the global config's ``models``
    mapping are restored in the ``finally`` block.
    """
    logger.info("\n=== Testing provider selection stability ===")

    # Test 1: Stability across multiple initializations with the same model
    logger.info("\nTest 1: Stability across multiple initializations with the same model")
    model_name = "llama-3.3-70b-versatile"
    provider = "groq"

    # Create multiple synthesizers with the same model
    synthesizers = []
    for i in range(3):
        logger.info(f"Creating synthesizer {i+1} with model {model_name}")
        synthesizer = ReportSynthesizer(model_name=model_name)
        synthesizers.append(synthesizer)
        logger.info(f"Synthesizer {i+1} provider: {synthesizer.model_config.get('provider')}")

    # Verify all synthesizers have the same provider
    providers = [s.model_config.get('provider') for s in synthesizers]
    logger.info(f"Providers across synthesizers: {providers}")
    assert all(p == provider for p in providers), "Provider not stable across multiple initializations"
    logger.info("✅ Provider stable across multiple initializations")

    # Test 2: Stability when switching between models
    logger.info("\nTest 2: Stability when switching between models")
    model_configs = [
        {"name": "llama-3.3-70b-versatile", "provider": "groq"},
        {"name": "gemini-2.0-flash", "provider": "gemini"},
        {"name": "claude-3-opus-20240229", "provider": "anthropic"},
        {"name": "gpt-4-turbo", "provider": "openai"},
    ]

    # Test switching between models multiple times
    for _ in range(2):  # Do two rounds of switching
        for model_config in model_configs:
            model_name = model_config["name"]
            expected_provider = model_config["provider"]
            logger.info(f"Switching to model {model_name} with expected provider {expected_provider}")
            synthesizer = ReportSynthesizer(model_name=model_name)
            actual_provider = synthesizer.model_config.get('provider')
            logger.info(f"Model: {model_name}, Expected provider: {expected_provider}, Actual provider: {actual_provider}")
            assert actual_provider == expected_provider, f"Provider mismatch for {model_name}: expected {expected_provider}, got {actual_provider}"

    logger.info("✅ Provider selection stable when switching between models")

    # Test 3: Stability with direct configuration changes
    logger.info("\nTest 3: Stability with direct configuration changes")
    test_model = "test-model-stability"

    # Get the global config instance
    from config.config import config as global_config

    # Save original config state (shallow copy is sufficient: the tests below
    # only add/replace whole entries under new test-only keys)
    original_models = global_config.config_data.get('models', {}).copy()

    try:
        # Ensure models dict exists
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}

        # Set up test model with groq provider
        global_config.config_data['models'][test_model] = {
            "provider": "groq",
            "model_name": test_model,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }

        # Create first synthesizer with groq provider
        logger.info(f"Creating first synthesizer with {test_model} using groq provider")
        synthesizer1 = ReportSynthesizer(model_name=test_model)
        provider1 = synthesizer1.model_config.get('provider')
        logger.info(f"Initial provider for {test_model}: {provider1}")

        # Change the provider in the global config
        global_config.config_data['models'][test_model]["provider"] = "anthropic"

        # Create second synthesizer with the updated config
        logger.info(f"Creating second synthesizer with {test_model} using anthropic provider")
        synthesizer2 = ReportSynthesizer(model_name=test_model)
        provider2 = synthesizer2.model_config.get('provider')
        logger.info(f"Updated provider for {test_model}: {provider2}")

        # Verify the provider was updated
        assert provider1 == "groq", f"Initial provider should be groq, got {provider1}"
        assert provider2 == "anthropic", f"Updated provider should be anthropic, got {provider2}"
        logger.info("✅ Provider selection responds correctly to configuration changes")

        # Test 4: Provider selection when using singleton vs. creating new instances
        logger.info("\nTest 4: Provider selection when using singleton vs. creating new instances")
        from report.report_synthesis import get_report_synthesizer

        # Set up a test model in the config
        test_model_singleton = "test-model-singleton"
        global_config.config_data['models'][test_model_singleton] = {
            "provider": "openai",
            "model_name": test_model_singleton,
            "temperature": 0.7,
            "max_tokens": 1024
        }

        # Get singleton instance with the test model
        logger.info(f"Getting singleton instance with {test_model_singleton}")
        singleton_synthesizer = get_report_synthesizer(model_name=test_model_singleton)
        singleton_provider = singleton_synthesizer.model_config.get('provider')
        logger.info(f"Singleton provider: {singleton_provider}")

        # Create a new instance with the same model
        logger.info(f"Creating new instance with {test_model_singleton}")
        new_synthesizer = ReportSynthesizer(model_name=test_model_singleton)
        new_provider = new_synthesizer.model_config.get('provider')
        logger.info(f"New instance provider: {new_provider}")

        # Verify both have the same provider
        assert singleton_provider == new_provider, f"Provider mismatch between singleton and new instance: {singleton_provider} vs {new_provider}"
        logger.info("✅ Provider selection consistent between singleton and new instances")

        # Test 5: Edge case with invalid provider
        logger.info("\nTest 5: Edge case with invalid provider")

        # Set up a test model with an invalid provider
        test_model_invalid = "test-model-invalid-provider"
        global_config.config_data['models'][test_model_invalid] = {
            "provider": "invalid_provider",  # This provider doesn't exist
            "model_name": test_model_invalid,
            "temperature": 0.5
        }

        # Create a synthesizer with the invalid provider model
        logger.info(f"Creating synthesizer with invalid provider for {test_model_invalid}")
        invalid_synthesizer = ReportSynthesizer(model_name=test_model_invalid)
        invalid_provider = invalid_synthesizer.model_config.get('provider')

        # The provider should remain as specified in the config, even if invalid.
        # This is important for error handling and debugging.
        logger.info(f"Provider for invalid model: {invalid_provider}")
        assert invalid_provider == "invalid_provider", f"Invalid provider should be preserved, got {invalid_provider}"
        logger.info("✅ Invalid provider preserved in configuration")

        # Test 6: Provider fallback mechanism
        logger.info("\nTest 6: Provider fallback mechanism")

        # Create a model with no explicit provider
        test_model_no_provider = "test-model-no-provider"
        global_config.config_data['models'][test_model_no_provider] = {
            # No provider specified
            "model_name": test_model_no_provider,
            "temperature": 0.5
        }

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with no explicit provider for {test_model_no_provider}")
        no_provider_synthesizer = ReportSynthesizer(model_name=test_model_no_provider)

        # The provider should be inferred based on the model name
        fallback_provider = no_provider_synthesizer.model_config.get('provider')
        logger.info(f"Fallback provider for model with no explicit provider: {fallback_provider}")

        # Since our test model name doesn't match any known pattern, it should default to groq
        assert fallback_provider == "groq", f"Expected fallback to groq, got {fallback_provider}"
        logger.info("✅ Provider fallback mechanism works correctly")

    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models


async def test_provider_selection_after_config_reload():
    """Test that provider selection remains stable after config reload.

    Builds a synthesizer, swaps the module-level ``config`` object for a
    freshly constructed :class:`Config` (simulating a reload), builds a second
    synthesizer, and asserts both resolved the same provider. The original
    global config object and its ``models`` mapping are always restored.
    """
    logger.info("\n=== Testing provider selection after config reload ===")

    # Hold the config module itself under an unambiguous alias so the global
    # ``config`` attribute can be swapped and restored safely. (A plain
    # ``from config.config import config`` followed by ``import config.config``
    # would rebind the local name ``config`` and make the restore fragile.)
    import config.config as config_module
    global_config = config_module.config

    # Save original config state
    original_models = global_config.config_data.get('models', {}).copy()
    original_config_path = global_config.config_path
    original_config = global_config

    try:
        # Set up a test model
        test_model = "test-model-config-reload"
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}
        global_config.config_data['models'][test_model] = {
            "provider": "anthropic",
            "model_name": test_model,
            "temperature": 0.5
        }

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with {test_model} before config reload")
        synthesizer_before = ReportSynthesizer(model_name=test_model)
        provider_before = synthesizer_before.model_config.get('provider')
        logger.info(f"Provider before reload: {provider_before}")

        # Simulate config reload by creating a new Config instance
        logger.info("Simulating config reload...")
        new_config = Config(config_path=original_config_path)

        # Add the same test model to the new config
        if 'models' not in new_config.config_data:
            new_config.config_data['models'] = {}
        new_config.config_data['models'][test_model] = {
            "provider": "anthropic",  # Same provider
            "model_name": test_model,
            "temperature": 0.5
        }

        # Temporarily replace the global config
        config_module.config = new_config

        # Create a new synthesizer after the reload
        logger.info(f"Creating synthesizer with {test_model} after config reload")
        synthesizer_after = ReportSynthesizer(model_name=test_model)
        provider_after = synthesizer_after.model_config.get('provider')
        logger.info(f"Provider after reload: {provider_after}")

        # Verify the provider remains the same
        assert provider_before == provider_after, f"Provider changed after config reload: {provider_before} vs {provider_after}"
        logger.info("✅ Provider selection stable after config reload")

    finally:
        # Restore original config state on the original object, then put the
        # original object back as the module-level global.
        original_config.config_data['models'] = original_models
        config_module.config = original_config


async def main():
    """Main function to run tests."""
    logger.info("Starting report synthesis tests...")

    await test_model_provider_selection()
    await test_provider_selection_stability()
    await test_provider_selection_after_config_reload()

    logger.info("All tests completed.")


if __name__ == "__main__":
    asyncio.run(main())