#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Test script for the report synthesis module, specifically to verify that
model provider selection works correctly.
"""
|
|
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import logging
|
|
from typing import Dict, Any, List
|
|
|
|
# Configure logging at INFO level and create this module's logger, which
# every test below uses to report its progress.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Append this script's parent directory to sys.path so that the project
# modules imported below can be resolved.
sys.path.append(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
|
|
|
|
from config.config import Config
|
|
from report.report_synthesis import ReportSynthesizer
|
|
|
|
async def test_model_provider_selection():
    """Exercise ReportSynthesizer once for each provider/model pair.

    For every pair the model entry is registered in the shared configuration,
    a synthesizer is created for it, its resolved provider and completion
    parameters are logged, and a short prompt is sent through
    ``generate_completion``. A failing completion is logged and the loop
    moves on to the next model, so one unavailable provider does not stop
    the remaining checks.

    The ``models`` section of the shared configuration is restored on exit,
    whether or not an error occurred.
    """
    logger.info("=== Testing basic model provider selection ===")

    # Register the models on the shared config instance, as the other tests
    # in this file do. Entries written to a freshly constructed Config()
    # would live on a separate object.
    # NOTE(review): assumes ReportSynthesizer reads the shared instance;
    # the sibling tests rely on the same assumption - confirm.
    from config.config import config as global_config

    # Test with different models and providers
    models_to_test = [
        {"provider": "groq", "model_name": "llama-3.3-70b-versatile"},
        {"provider": "gemini", "model_name": "gemini-2.0-flash"},
        {"provider": "anthropic", "model_name": "claude-3-opus-20240229"},
        {"provider": "openai", "model_name": "gpt-4-turbo"},
    ]

    # Snapshot the models mapping so the overrides below can be undone.
    original_models = global_config.config_data.get('models', {}).copy()

    try:
        for model_config in models_to_test:
            provider = model_config["provider"]
            model_name = model_config["model_name"]

            logger.info(f"\n\n===== Testing model: {model_name} with provider: {provider} =====")

            # Make sure the models mapping exists, then (re)register the
            # model under test with the provider being exercised.
            models = global_config.config_data.setdefault('models', {})
            models[model_name] = {
                "provider": provider,
                "model_name": model_name,
                "temperature": 0.5,
                "max_tokens": 2048,
                "top_p": 1.0,
            }

            # Create the synthesizer with the model name
            synthesizer = ReportSynthesizer(model_name=model_name)

            # Log what the synthesizer resolved so it can be checked by eye.
            logger.info(f"Synthesizer initialized with model: {synthesizer.model_name}")
            logger.info(f"Synthesizer provider: {synthesizer.model_config.get('provider')}")

            # Get completion parameters to verify they're set correctly
            params = synthesizer._get_completion_params()
            logger.info(f"Completion parameters: {params}")

            # Build a fresh prompt per model so nothing carries over between
            # iterations.
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say hello and identify yourself as the model you're running on."},
            ]

            try:
                # Test the generate_completion method
                logger.info("Testing generate_completion method...")
                response = await synthesizer.generate_completion(messages)
                logger.info(f"Response: {response[:100]}...")  # Show first 100 chars
            except Exception as e:
                # Deliberately best-effort: log the failure and keep going
                # with the next model.
                logger.error(f"Error generating completion: {e}")
                continue

            logger.info(f"===== Test completed for {model_name} with provider {provider} =====\n")
    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models
|
|
|
|
async def test_provider_selection_stability():
    """Run six provider-selection scenarios against ReportSynthesizer.

    Tests 1 and 2 only construct synthesizers for models that are expected to
    be configured already. Tests 3 to 6 register temporary models in the
    shared configuration; the ``models`` section is snapshotted beforehand
    and put back in a ``finally`` clause.

    Raises:
        AssertionError: if any scenario observes an unexpected provider.
    """
    logger.info("\n=== Testing provider selection stability ===")

    _check_stable_across_initializations()
    _check_stable_when_switching_models()

    # Get the global config instance
    from config.config import config as global_config

    # Shallow copy is enough: the checks below only add new model entries and
    # never modify the ones that already existed.
    original_models = global_config.config_data.get('models', {}).copy()

    try:
        # Ensure models dict exists
        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}

        _check_responds_to_config_change(global_config)
        _check_singleton_matches_new_instance(global_config)
        _check_invalid_provider_preserved(global_config)
        _check_provider_fallback(global_config)
    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models


def _check_stable_across_initializations():
    """Test 1: building several synthesizers for one model gives one provider."""
    logger.info("\nTest 1: Stability across multiple initializations with the same model")
    model_name = "llama-3.3-70b-versatile"
    provider = "groq"

    # Create multiple synthesizers with the same model
    synthesizers = []
    for i in range(3):
        logger.info(f"Creating synthesizer {i+1} with model {model_name}")
        synthesizer = ReportSynthesizer(model_name=model_name)
        synthesizers.append(synthesizer)
        logger.info(f"Synthesizer {i+1} provider: {synthesizer.model_config.get('provider')}")

    # Verify all synthesizers have the same provider
    providers = [s.model_config.get('provider') for s in synthesizers]
    logger.info(f"Providers across synthesizers: {providers}")
    assert all(p == provider for p in providers), "Provider not stable across multiple initializations"
    logger.info("✅ Provider stable across multiple initializations")


def _check_stable_when_switching_models():
    """Test 2: alternating between models always yields each model's provider."""
    logger.info("\nTest 2: Stability when switching between models")
    model_configs = [
        {"name": "llama-3.3-70b-versatile", "provider": "groq"},
        {"name": "gemini-2.0-flash", "provider": "gemini"},
        {"name": "claude-3-opus-20240229", "provider": "anthropic"},
        {"name": "gpt-4-turbo", "provider": "openai"},
    ]

    # Two full rounds, so every model is also reached after each of the others.
    for _ in range(2):
        for model_config in model_configs:
            model_name = model_config["name"]
            expected_provider = model_config["provider"]

            logger.info(f"Switching to model {model_name} with expected provider {expected_provider}")
            synthesizer = ReportSynthesizer(model_name=model_name)
            actual_provider = synthesizer.model_config.get('provider')

            logger.info(f"Model: {model_name}, Expected provider: {expected_provider}, Actual provider: {actual_provider}")
            assert actual_provider == expected_provider, f"Provider mismatch for {model_name}: expected {expected_provider}, got {actual_provider}"

    logger.info("✅ Provider selection stable when switching between models")


def _check_responds_to_config_change(global_config):
    """Test 3: a provider edited in the config is seen by the next synthesizer."""
    logger.info("\nTest 3: Stability with direct configuration changes")
    test_model = "test-model-stability"

    # Set up test model with groq provider
    global_config.config_data['models'][test_model] = {
        "provider": "groq",
        "model_name": test_model,
        "temperature": 0.5,
        "max_tokens": 2048,
        "top_p": 1.0,
    }

    # Create first synthesizer with groq provider
    logger.info(f"Creating first synthesizer with {test_model} using groq provider")
    synthesizer1 = ReportSynthesizer(model_name=test_model)
    provider1 = synthesizer1.model_config.get('provider')
    logger.info(f"Initial provider for {test_model}: {provider1}")

    # Change the provider in the global config
    global_config.config_data['models'][test_model]["provider"] = "anthropic"

    # Create second synthesizer with the updated config
    logger.info(f"Creating second synthesizer with {test_model} using anthropic provider")
    synthesizer2 = ReportSynthesizer(model_name=test_model)
    provider2 = synthesizer2.model_config.get('provider')
    logger.info(f"Updated provider for {test_model}: {provider2}")

    # Verify the provider was updated
    assert provider1 == "groq", f"Initial provider should be groq, got {provider1}"
    assert provider2 == "anthropic", f"Updated provider should be anthropic, got {provider2}"
    logger.info("✅ Provider selection responds correctly to configuration changes")


def _check_singleton_matches_new_instance(global_config):
    """Test 4: the accessor function and a direct instance agree on the provider."""
    logger.info("\nTest 4: Provider selection when using singleton vs. creating new instances")

    from report.report_synthesis import get_report_synthesizer

    # Set up a test model in the config
    test_model_singleton = "test-model-singleton"
    global_config.config_data['models'][test_model_singleton] = {
        "provider": "openai",
        "model_name": test_model_singleton,
        "temperature": 0.7,
        "max_tokens": 1024,
    }

    # Get singleton instance with the test model
    logger.info(f"Getting singleton instance with {test_model_singleton}")
    singleton_synthesizer = get_report_synthesizer(model_name=test_model_singleton)
    singleton_provider = singleton_synthesizer.model_config.get('provider')
    logger.info(f"Singleton provider: {singleton_provider}")

    # Create a new instance with the same model
    logger.info(f"Creating new instance with {test_model_singleton}")
    new_synthesizer = ReportSynthesizer(model_name=test_model_singleton)
    new_provider = new_synthesizer.model_config.get('provider')
    logger.info(f"New instance provider: {new_provider}")

    # Verify both have the same provider
    assert singleton_provider == new_provider, f"Provider mismatch between singleton and new instance: {singleton_provider} vs {new_provider}"
    logger.info("✅ Provider selection consistent between singleton and new instances")


def _check_invalid_provider_preserved(global_config):
    """Test 5: an unknown provider name is kept as configured, not rewritten."""
    logger.info("\nTest 5: Edge case with invalid provider")

    # Set up a test model with an invalid provider
    test_model_invalid = "test-model-invalid-provider"
    global_config.config_data['models'][test_model_invalid] = {
        "provider": "invalid_provider",  # This provider doesn't exist
        "model_name": test_model_invalid,
        "temperature": 0.5,
    }

    # Create a synthesizer with the invalid provider model
    logger.info(f"Creating synthesizer with invalid provider for {test_model_invalid}")
    invalid_synthesizer = ReportSynthesizer(model_name=test_model_invalid)
    invalid_provider = invalid_synthesizer.model_config.get('provider')

    # The provider should remain as specified in the config, even if invalid.
    # This is important for error handling and debugging.
    logger.info(f"Provider for invalid model: {invalid_provider}")
    assert invalid_provider == "invalid_provider", f"Invalid provider should be preserved, got {invalid_provider}"
    logger.info("✅ Invalid provider preserved in configuration")


def _check_provider_fallback(global_config):
    """Test 6: a model configured without a provider resolves to ``groq``."""
    logger.info("\nTest 6: Provider fallback mechanism")

    # Create a model with no explicit provider
    test_model_no_provider = "test-model-no-provider"
    global_config.config_data['models'][test_model_no_provider] = {
        # No provider specified
        "model_name": test_model_no_provider,
        "temperature": 0.5,
    }

    # Create a synthesizer with this model
    logger.info(f"Creating synthesizer with no explicit provider for {test_model_no_provider}")
    no_provider_synthesizer = ReportSynthesizer(model_name=test_model_no_provider)

    # The provider should be inferred based on the model name
    fallback_provider = no_provider_synthesizer.model_config.get('provider')
    logger.info(f"Fallback provider for model with no explicit provider: {fallback_provider}")

    # Since our test model name doesn't match any known pattern, it should default to groq
    assert fallback_provider == "groq", f"Expected fallback to groq, got {fallback_provider}"
    logger.info("✅ Provider fallback mechanism works correctly")
|
|
|
|
async def test_provider_selection_after_config_reload():
    """Check that a model resolves to the same provider after a config reload.

    A test model is registered on the shared configuration and a synthesizer
    is built for it. A reload is then simulated by constructing a second
    ``Config`` from the same path, registering the same model on it, and
    installing it as ``config.config.config``. A second synthesizer is built
    and both providers are compared.

    The shared ``models`` section and the module-level ``config`` attribute
    are restored on exit.

    Raises:
        AssertionError: if the provider differs before and after the reload.
    """
    logger.info("\n=== Testing provider selection after config reload ===")

    # Bind the module under an unambiguous alias: the module, its Config
    # class and its ``config`` instance would otherwise all compete for the
    # name "config" inside this function.
    import config.config as config_module

    # Get the global config instance
    global_config = config_module.config

    # Save original config state
    original_models = global_config.config_data.get('models', {}).copy()
    original_config_path = global_config.config_path

    # Set only once the module attribute has actually been replaced, so the
    # cleanup knows whether there is anything to put back.
    config_swapped = False

    try:
        # Set up a test model
        test_model = "test-model-config-reload"
        model_settings = {
            "provider": "anthropic",
            "model_name": test_model,
            "temperature": 0.5,
        }

        if 'models' not in global_config.config_data:
            global_config.config_data['models'] = {}
        # Each config gets its own dict so the two never share state.
        global_config.config_data['models'][test_model] = dict(model_settings)

        # Create a synthesizer with this model
        logger.info(f"Creating synthesizer with {test_model} before config reload")
        synthesizer_before = ReportSynthesizer(model_name=test_model)
        provider_before = synthesizer_before.model_config.get('provider')
        logger.info(f"Provider before reload: {provider_before}")

        # Simulate config reload by creating a new Config instance
        logger.info("Simulating config reload...")
        new_config = config_module.Config(config_path=original_config_path)

        # Add the same test model (same provider) to the new config
        if 'models' not in new_config.config_data:
            new_config.config_data['models'] = {}
        new_config.config_data['models'][test_model] = dict(model_settings)

        # Temporarily replace the global config.
        # NOTE(review): modules that did ``from config.config import config``
        # at import time keep their reference to the old instance; confirm
        # that ReportSynthesizer actually observes this swap.
        config_module.config = new_config
        config_swapped = True

        # Create a new synthesizer after the reload
        logger.info(f"Creating synthesizer with {test_model} after config reload")
        synthesizer_after = ReportSynthesizer(model_name=test_model)
        provider_after = synthesizer_after.model_config.get('provider')
        logger.info(f"Provider after reload: {provider_after}")

        # Verify the provider remains the same
        assert provider_before == provider_after, f"Provider changed after config reload: {provider_before} vs {provider_after}"
        logger.info("✅ Provider selection stable after config reload")

    finally:
        # Restore original config state
        global_config.config_data['models'] = original_models
        # Restore original global config
        if config_swapped:
            config_module.config = global_config
|
|
|
|
async def main():
    """Run every test coroutine in this script, one after another."""
    logger.info("Starting report synthesis tests...")

    # Execution order matters only for log readability; each test is awaited
    # to completion before the next one starts.
    test_sequence = (
        test_model_provider_selection,
        test_provider_selection_stability,
        test_provider_selection_after_config_reload,
    )
    for run_test in test_sequence:
        await run_test()

    logger.info("All tests completed.")


if __name__ == "__main__":
    # Script entry point: run the async test sequence.
    asyncio.run(main())
|