Fix model provider selection in report generation to ensure the correct provider is used based on the config

Steve White 2025-03-18 17:57:09 -05:00
parent d76cd9d79b
commit 15357890ea
1 changed file with 44 additions and 5 deletions


@@ -96,7 +96,25 @@ class ReportSynthesizer:
 
     def _setup_provider(self) -> None:
         """Set up the LLM provider based on the model configuration."""
-        provider = self.model_config.get('provider', 'groq')
+        # Determine provider based on model name if not explicitly specified
+        if 'provider' in self.model_config:
+            provider = self.model_config['provider']
+        elif 'gemini' in self.model_name.lower():
+            provider = 'gemini'
+            # Update model_config to reflect the correct provider
+            self.model_config['provider'] = 'gemini'
+        elif 'claude' in self.model_name.lower():
+            provider = 'anthropic'
+            self.model_config['provider'] = 'anthropic'
+        elif 'gpt' in self.model_name.lower() or 'text-embedding' in self.model_name.lower():
+            provider = 'openai'
+            self.model_config['provider'] = 'openai'
+        elif 'mistral' in self.model_name.lower():
+            provider = 'mistral'
+            self.model_config['provider'] = 'mistral'
+        else:
+            provider = 'groq'  # Default to groq for other models
+            self.model_config['provider'] = 'groq'
 
         # Log detailed model information for debugging
         logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
@@ -207,9 +225,12 @@ class ReportSynthesizer:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
-        # Get provider from model config
+        # Get provider from model config - this should be set correctly from the config file
         provider = self.model_config.get('provider', 'groq').lower()
 
+        # Log the provider and model being used for debugging
+        logger.info(f"Using provider: {provider} with model: {self.model_name}")
+
         # Special handling for Gemini models - they use 'user' and 'model' roles
         if provider == 'gemini':
             formatted_messages = []
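The hunk is truncated before the loop that fills `formatted_messages`, but the comment indicates a role remap. A plausible sketch of such a conversion, assuming OpenAI-style input messages (the actual loop body is not shown in this diff, so details like the system-role handling are assumptions):

```python
def to_gemini_roles(messages: list[dict]) -> list[dict]:
    """Map OpenAI-style roles to Gemini's 'user'/'model' scheme (illustrative)."""
    formatted_messages = []
    for msg in messages:
        role = msg.get('role', 'user')
        if role == 'assistant':
            role = 'model'  # Gemini calls the assistant side 'model'
        elif role == 'system':
            role = 'user'   # Gemini has no system role; fold it into a user turn
        formatted_messages.append({'role': role, 'content': msg.get('content', '')})
    return formatted_messages
```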
@@ -236,11 +257,29 @@ class ReportSynthesizer:
         # Get completion parameters
         params = self._get_completion_params()
 
-        # Double-check that we're using the correct model
+        # Use the provider from the model config - this should be set correctly from the config file
+        # Different providers have different formatting requirements for the model parameter
         if provider == 'gemini':
-            params['model'] = f"gemini/{self.model_name}"
+            # For Gemini models, we use the vertex_ai provider in litellm
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'vertex_ai'
         elif provider == 'groq':
-            params['model'] = f"groq/{self.model_name}"
+            # For Groq models, we need to prefix with provider
+            if not self.model_name.startswith('groq/'):
+                params['model'] = f"groq/{self.model_name}"
+            else:
+                params['model'] = self.model_name
+        elif provider == 'anthropic':
+            # For Claude models
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'anthropic'
+        elif provider == 'openai':
+            # For OpenAI models
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'openai'
+
+        # Log the final model parameter being used
+        logger.info(f"Final model parameter: {params.get('model', 'unknown')} with provider: {provider}")
 
         # Log the actual parameters being used for the LLM call
         safe_params = params.copy()
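For context, litellm accepts either routing style that the branches above produce: a provider-prefixed model string, or a bare model name plus `custom_llm_provider`. A minimal sketch of both call shapes (model names and messages are placeholders; running this requires the matching API keys in the environment):

```python
import litellm

# Provider-prefixed routing: litellm parses the 'groq/' prefix from the model string
resp = litellm.completion(
    model='groq/llama3-70b-8192',
    messages=[{'role': 'user', 'content': 'Summarize the findings.'}],
)

# Explicit routing: the model name stays bare, the provider is passed separately
resp = litellm.completion(
    model='claude-3-opus-20240229',
    messages=[{'role': 'user', 'content': 'Summarize the findings.'}],
    custom_llm_provider='anthropic',
)

print(resp.choices[0].message.content)
```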