Fix model provider selection in report generation to ensure correct provider is used based on config
This commit is contained in:
parent
d76cd9d79b
commit
15357890ea
|
@ -96,7 +96,25 @@ class ReportSynthesizer:
|
|||
|
||||
def _setup_provider(self) -> None:
|
||||
"""Set up the LLM provider based on the model configuration."""
|
||||
provider = self.model_config.get('provider', 'groq')
|
||||
# Determine provider based on model name if not explicitly specified
|
||||
if 'provider' in self.model_config:
|
||||
provider = self.model_config['provider']
|
||||
elif 'gemini' in self.model_name.lower():
|
||||
provider = 'gemini'
|
||||
# Update model_config to reflect the correct provider
|
||||
self.model_config['provider'] = 'gemini'
|
||||
elif 'claude' in self.model_name.lower():
|
||||
provider = 'anthropic'
|
||||
self.model_config['provider'] = 'anthropic'
|
||||
elif 'gpt' in self.model_name.lower() or 'text-embedding' in self.model_name.lower():
|
||||
provider = 'openai'
|
||||
self.model_config['provider'] = 'openai'
|
||||
elif 'mistral' in self.model_name.lower():
|
||||
provider = 'mistral'
|
||||
self.model_config['provider'] = 'mistral'
|
||||
else:
|
||||
provider = 'groq' # Default to groq for other models
|
||||
self.model_config['provider'] = 'groq'
|
||||
|
||||
# Log detailed model information for debugging
|
||||
logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
|
||||
|
@ -207,9 +225,12 @@ class ReportSynthesizer:
|
|||
If stream is False, returns the completion text as a string
|
||||
If stream is True, returns the completion response object for streaming
|
||||
"""
|
||||
# Get provider from model config
|
||||
# Get provider from model config - this should be set correctly from the config file
|
||||
provider = self.model_config.get('provider', 'groq').lower()
|
||||
|
||||
# Log the provider and model being used for debugging
|
||||
logger.info(f"Using provider: {provider} with model: {self.model_name}")
|
||||
|
||||
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||
if provider == 'gemini':
|
||||
formatted_messages = []
|
||||
|
@ -236,11 +257,29 @@ class ReportSynthesizer:
|
|||
# Get completion parameters
|
||||
params = self._get_completion_params()
|
||||
|
||||
# Double-check that we're using the correct model
|
||||
# Use the provider from the model config - this should be set correctly from the config file
|
||||
# Different providers have different formatting requirements for the model parameter
|
||||
if provider == 'gemini':
|
||||
params['model'] = f"gemini/{self.model_name}"
|
||||
# For Gemini models, we use the vertex_ai provider in litellm
|
||||
params['model'] = self.model_name
|
||||
params['custom_llm_provider'] = 'vertex_ai'
|
||||
elif provider == 'groq':
|
||||
params['model'] = f"groq/{self.model_name}"
|
||||
# For Groq models, we need to prefix with provider
|
||||
if not self.model_name.startswith('groq/'):
|
||||
params['model'] = f"groq/{self.model_name}"
|
||||
else:
|
||||
params['model'] = self.model_name
|
||||
elif provider == 'anthropic':
|
||||
# For Claude models
|
||||
params['model'] = self.model_name
|
||||
params['custom_llm_provider'] = 'anthropic'
|
||||
elif provider == 'openai':
|
||||
# For OpenAI models
|
||||
params['model'] = self.model_name
|
||||
params['custom_llm_provider'] = 'openai'
|
||||
|
||||
# Log the final model parameter being used
|
||||
logger.info(f"Final model parameter: {params.get('model', 'unknown')} with provider: {provider}")
|
||||
|
||||
# Log the actual parameters being used for the LLM call
|
||||
safe_params = params.copy()
|
||||
|
|
Loading…
Reference in New Issue