#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Test script for the report synthesis module, specifically to verify
model provider selection works correctly.
"""
import os
import sys
import asyncio
import logging
from typing import Dict, Any, List

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add the project root (two levels up from this file) to the path so the
# config and report packages can be imported
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config.config import Config
from report.report_synthesis import ReportSynthesizer


async def test_model_provider_selection():
    """Test that model provider selection works correctly."""
    logger.info("=== Testing basic model provider selection ===")

    # Initialize config
    config = Config()

    # Test with different models and providers
    models_to_test = [
        {"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
        {"provider": "gemini", "model_name": "gemini-2.0-flash"},
        {"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
        {"provider": "openai", "model_name": "gpt-4-turbo"},
    ]

    for model_config in models_to_test:
        provider = model_config["provider"]
        model_name = model_config["model_name"]
        logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")

        # Create a synthesizer with the specified model:
        # first update the config to use the specified provider
        config.config_data['models'] = config.config_data.get('models', {})
        config.config_data['models'][model_name] = {
            "provider": provider,
            "model_name": model_name,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }
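        # (This assumes Config behaves as a shared/singleton instance, so the
        # entries added here are visible to ReportSynthesizer; if Config() made
        # an independent copy, these per-model settings would have no effect.)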

        # Create the synthesizer with the model name
        synthesizer = ReportSynthesizer(model_name=model_name)

        # Verify the model and provider are set correctly
        logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
        logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")

        # Get completion parameters to verify they're set correctly
        params = synthesizer._get_completion_params()
        logger.info(f"Completion parameters: {params}")

        # Create a simple test message
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello and identify yourself as the model you're running on."}
        ]
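
        # NOTE: generate_completion presumably issues a live API call, so the
        # relevant credentials (e.g. GROQ_API_KEY, ANTHROPIC_API_KEY,
        # OPENAI_API_KEY -- exact variable names depend on the provider wiring)
        # must be set in the environment for each provider under test.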
        try:
            # Test the generate_completion method
            logger.info("Testing generate_completion method...")
            response = await synthesizer.generate_completion(messages)
            logger.info(f"Response: {response[:100]}...")  # Show first 100 chars
        except Exception as e:
            logger.error(f"Error generating completion: {e}")
            # Continue with the next model even if this one fails
            continue

        logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")


async def test_provider_selection_stability():
    """Test that provider selection remains stable across various scenarios."""
    logger.info("\n=== Testing provider selection stability ===")

    # Test 1: Stability across multiple initializations with the same model
    logger.info("\nTest 1: Stability across multiple initializations with the same model")
    model_name = "llama-3.3-70b-versatile"
    provider = "groq"

    # Create multiple synthesizers with the same model
    synthesizers = []
    for i in range(3):
        logger.info(f"Creating synthesizer {i+1} with model {model_name}")
        synthesizer = ReportSynthesizer(model_name=model_name)
        synthesizers.append(synthesizer)
        logger.info(f"Synthesizer {i+1} provider: {synthesizer.model_config.get('provider')}")

    # Verify all synthesizers have the same provider
    providers = [s.model_config.get('provider') for s in synthesizers]
    logger.info(f"Providers across synthesizers: {providers}")
    assert all(p == provider for p in providers), "Provider not stable across multiple initializations"
    logger.info("✅ Provider stable across multiple initializations")

    # Test 2: Stability when switching between models
    logger.info("\nTest 2: Stability when switching between models")
    model_configs = [
        {"name": "llama-3.3-70b-versatile", "provider": "groq"},
        {"name": "gemini-2.0-flash", "provider": "gemini"},
        {"name": "claude-3-opus-20240229", "provider": "anthropic"},
        {"name": "gpt-4-turbo", "provider": "openai"},
    ]

    # Test switching between models multiple times
    for _ in range(2):  # Do two rounds of switching
        for model_config in model_configs:
            model_name = model_config["name"]
            expected_provider = model_config["provider"]
            logger.info(f"Switching to model {model_name} with expected provider {expected_provider}")
            synthesizer = ReportSynthesizer(model_name=model_name)
            actual_provider = synthesizer.model_config.get('provider')
            logger.info(f"Model: {model_name}, Expected provider: {expected_provider}, Actual provider: {actual_provider}")
            assert actual_provider == expected_provider, f"Provider mismatch for {model_name}: expected {expected_provider}, got {actual_provider}"
    logger.info("✅ Provider selection stable when switching between models")

    # Test 3: Stability with direct configuration changes
    logger.info("\nTest 3: Stability with direct configuration changes")
    test_model = "test-model-stability"

    # Get the global config instance
    from config.config import config as global_config

    # Save original config state
    original_models = global_config.config_data.get('models', {}).copy()

    try:
        # Ensure models dict exists
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}

        # Set up test model with groq provider
        global_config.config_data['models'][test_model] = {
            "provider": "groq",
            "model_name": test_model,
            "temperature": 0.5,
            "max_tokens": 2048,
            "top_p": 1.0
        }

        # Create first synthesizer with groq provider
        logger.info(f"Creating first synthesizer with {test_model} using groq provider")
        synthesizer1 = ReportSynthesizer(model_name=test_model)
        provider1 = synthesizer1.model_config.get('provider')
        logger.info(f"Initial provider for {test_model}: {provider1}")

        # Change the provider in the global config
        global_config.config_data['models'][test_model]["provider"] = "anthropic"

        # Create second synthesizer with the updated config
        logger.info(f"Creating second synthesizer with {test_model} using anthropic provider")
        synthesizer2 = ReportSynthesizer(model_name=test_model)
        provider2 = synthesizer2.model_config.get('provider')
        logger.info(f"Updated provider for {test_model}: {provider2}")

        # Verify the provider was updated
        assert provider1 == "groq", f"Initial provider should be groq, got {provider1}"
        assert provider2 == "anthropic", f"Updated provider should be anthropic, got {provider2}"
        logger.info("✅ Provider selection responds correctly to configuration changes")

        # Test 4: Provider selection when using singleton vs. creating new instances
        logger.info("\nTest 4: Provider selection when using singleton vs. creating new instances")
        from report.report_synthesis import get_report_synthesizer

        # Set up a test model in the config
        test_model_singleton = "test-model-singleton"
        global_config.config_data['models'][test_model_singleton] = {
            "provider": "openai",
            "model_name": test_model_singleton,
            "temperature": 0.7,
            "max_tokens": 1024
        }

        # Get singleton instance with the test model
        logger.info(f"Getting singleton instance with {test_model_singleton}")
        singleton_synthesizer = get_report_synthesizer(model_name=test_model_singleton)
        singleton_provider = singleton_synthesizer.model_config.get('provider')
        logger.info(f"Singleton provider: {singleton_provider}")

        # Create a new instance with the same model
        logger.info(f"Creating new instance with {test_model_singleton}")
        new_synthesizer = ReportSynthesizer(model_name=test_model_singleton)
        new_provider = new_synthesizer.model_config.get('provider')
        logger.info(f"New instance provider: {new_provider}")

        # Verify both have the same provider
        assert singleton_provider == new_provider, f"Provider mismatch between singleton and new instance: {singleton_provider} vs {new_provider}"
        logger.info("✅ Provider selection consistent between singleton and new instances")

        # Test 5: Edge case with invalid provider
        logger.info("\nTest 5: Edge case with invalid provider")

        # Set up a test model with an invalid provider
        test_model_invalid = "test-model-invalid-provider"
        global_config.config_data['models'][test_model_invalid] = {
            "provider": "invalid_provider",  # This provider doesn't exist
            "model_name": test_model_invalid,
            "temperature": 0.5
        }

        # Create a synthesizer with the invalid provider model
        logger.info(f"Creating synthesizer with invalid provider for {test_model_invalid}")
        invalid_synthesizer = ReportSynthesizer(model_name=test_model_invalid)
        invalid_provider = invalid_synthesizer.model_config.get('provider')

        # The provider should remain as specified in the config, even if invalid;
        # this is important for error handling and debugging
        logger.info(f"Provider for invalid model: {invalid_provider}")
        assert invalid_provider == "invalid_provider", f"Invalid provider should be preserved, got {invalid_provider}"
        logger.info("✅ Invalid provider preserved in configuration")

        # Test 6: Provider fallback mechanism
        logger.info("\nTest 6: Provider fallback mechanism")

        # Create a model with no explicit provider
        test_model_no_provider = "test-model-no-provider"
        global_config.config_data['models'][test_model_no_provider] = {
            # No provider specified
            "model_name": test_model_no_provider,
            "temperature": 0.5
        }

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with no explicit provider for {test_model_no_provider}")
        no_provider_synthesizer = ReportSynthesizer(model_name=test_model_no_provider)

        # The provider should be inferred based on the model name
        fallback_provider = no_provider_synthesizer.model_config.get('provider')
        logger.info(f"Fallback provider for model with no explicit provider: {fallback_provider}")
        # Since our test model name doesn't match any known pattern, it should default to groq
        assert fallback_provider == "groq", f"Expected fallback to groq, got {fallback_provider}"
        logger.info("✅ Provider fallback mechanism works correctly")
    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models


async def test_provider_selection_after_config_reload():
    """Test that provider selection remains stable after a config reload."""
    logger.info("\n=== Testing provider selection after config reload ===")

    # Get the global config instance
    from config.config import config as global_config
    from config.config import Config

    # Save original config state
    original_models = global_config.config_data.get('models', {}).copy()
    original_config_path = global_config.config_path

    try:
        # Set up a test model
        test_model = "test-model-config-reload"
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}
        global_config.config_data['models'][test_model] = {
            "provider": "anthropic",
            "model_name": test_model,
            "temperature": 0.5
        }

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with {test_model} before config reload")
        synthesizer_before = ReportSynthesizer(model_name=test_model)
        provider_before = synthesizer_before.model_config.get('provider')
        logger.info(f"Provider before reload: {provider_before}")

        # Simulate a config reload by creating a new Config instance
        logger.info("Simulating config reload...")
        new_config = Config(config_path=original_config_path)

        # Add the same test model to the new config
        if 'models' not in new_config.config_data:
            new_config.config_data['models'] = {}
        new_config.config_data['models'][test_model] = {
            "provider": "anthropic",  # Same provider
            "model_name": test_model,
            "temperature": 0.5
        }
        # Temporarily swap the global config instance on the config.config
        # module for the freshly loaded one
        import config.config as config_module
        original_config = config_module.config
        config_module.config = new_config
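        # (Caveat: modules that bound the instance at import time via
        # `from config.config import config` keep their old reference; the swap
        # is only visible to code that looks the attribute up at call time.)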

        # Create a new synthesizer after the reload
        logger.info(f"Creating synthesizer with {test_model} after config reload")
        synthesizer_after = ReportSynthesizer(model_name=test_model)
        provider_after = synthesizer_after.model_config.get('provider')
        logger.info(f"Provider after reload: {provider_after}")

        # Verify the provider remains the same
        assert provider_before == provider_after, f"Provider changed after config reload: {provider_before} vs {provider_after}"
        logger.info("✅ Provider selection stable after config reload")
    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models
        # Restore the original global config instance if it was swapped
        if 'original_config' in locals():
            config_module.config = original_config


async def main():
    """Main function to run the tests."""
    logger.info("Starting report synthesis tests...")
    await test_model_provider_selection()
    await test_provider_selection_stability()
    await test_provider_selection_after_config_reload()
    logger.info("All tests completed.")


if __name__ == "__main__":
    asyncio.run(main())
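# Usage (assuming the project layout described above, with this file at
# ira/report/report_synthesis_test.py):
#   python ira/report/report_synthesis_test.py
# The tests are plain asyncio coroutines rather than pytest cases, so they are
# run through this entry point instead of a test runner.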