Fix model provider selection in report generation to ensure correct provider is used based on config
This commit is contained in:
parent
d76cd9d79b
commit
15357890ea
|
@ -96,7 +96,25 @@ class ReportSynthesizer:
|
||||||
|
|
||||||
def _setup_provider(self) -> None:
|
def _setup_provider(self) -> None:
|
||||||
"""Set up the LLM provider based on the model configuration."""
|
"""Set up the LLM provider based on the model configuration."""
|
||||||
provider = self.model_config.get('provider', 'groq')
|
# Determine provider based on model name if not explicitly specified
|
||||||
|
if 'provider' in self.model_config:
|
||||||
|
provider = self.model_config['provider']
|
||||||
|
elif 'gemini' in self.model_name.lower():
|
||||||
|
provider = 'gemini'
|
||||||
|
# Update model_config to reflect the correct provider
|
||||||
|
self.model_config['provider'] = 'gemini'
|
||||||
|
elif 'claude' in self.model_name.lower():
|
||||||
|
provider = 'anthropic'
|
||||||
|
self.model_config['provider'] = 'anthropic'
|
||||||
|
elif 'gpt' in self.model_name.lower() or 'text-embedding' in self.model_name.lower():
|
||||||
|
provider = 'openai'
|
||||||
|
self.model_config['provider'] = 'openai'
|
||||||
|
elif 'mistral' in self.model_name.lower():
|
||||||
|
provider = 'mistral'
|
||||||
|
self.model_config['provider'] = 'mistral'
|
||||||
|
else:
|
||||||
|
provider = 'groq' # Default to groq for other models
|
||||||
|
self.model_config['provider'] = 'groq'
|
||||||
|
|
||||||
# Log detailed model information for debugging
|
# Log detailed model information for debugging
|
||||||
logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
|
logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
|
||||||
|
@ -207,9 +225,12 @@ class ReportSynthesizer:
|
||||||
If stream is False, returns the completion text as a string
|
If stream is False, returns the completion text as a string
|
||||||
If stream is True, returns the completion response object for streaming
|
If stream is True, returns the completion response object for streaming
|
||||||
"""
|
"""
|
||||||
# Get provider from model config
|
# Get provider from model config - this should be set correctly from the config file
|
||||||
provider = self.model_config.get('provider', 'groq').lower()
|
provider = self.model_config.get('provider', 'groq').lower()
|
||||||
|
|
||||||
|
# Log the provider and model being used for debugging
|
||||||
|
logger.info(f"Using provider: {provider} with model: {self.model_name}")
|
||||||
|
|
||||||
# Special handling for Gemini models - they use 'user' and 'model' roles
|
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||||
if provider == 'gemini':
|
if provider == 'gemini':
|
||||||
formatted_messages = []
|
formatted_messages = []
|
||||||
|
@ -236,11 +257,29 @@ class ReportSynthesizer:
|
||||||
# Get completion parameters
|
# Get completion parameters
|
||||||
params = self._get_completion_params()
|
params = self._get_completion_params()
|
||||||
|
|
||||||
# Double-check that we're using the correct model
|
# Use the provider from the model config - this should be set correctly from the config file
|
||||||
|
# Different providers have different formatting requirements for the model parameter
|
||||||
if provider == 'gemini':
|
if provider == 'gemini':
|
||||||
params['model'] = f"gemini/{self.model_name}"
|
# For Gemini models, we use the vertex_ai provider in litellm
|
||||||
|
params['model'] = self.model_name
|
||||||
|
params['custom_llm_provider'] = 'vertex_ai'
|
||||||
elif provider == 'groq':
|
elif provider == 'groq':
|
||||||
params['model'] = f"groq/{self.model_name}"
|
# For Groq models, we need to prefix with provider
|
||||||
|
if not self.model_name.startswith('groq/'):
|
||||||
|
params['model'] = f"groq/{self.model_name}"
|
||||||
|
else:
|
||||||
|
params['model'] = self.model_name
|
||||||
|
elif provider == 'anthropic':
|
||||||
|
# For Claude models
|
||||||
|
params['model'] = self.model_name
|
||||||
|
params['custom_llm_provider'] = 'anthropic'
|
||||||
|
elif provider == 'openai':
|
||||||
|
# For OpenAI models
|
||||||
|
params['model'] = self.model_name
|
||||||
|
params['custom_llm_provider'] = 'openai'
|
||||||
|
|
||||||
|
# Log the final model parameter being used
|
||||||
|
logger.info(f"Final model parameter: {params.get('model', 'unknown')} with provider: {provider}")
|
||||||
|
|
||||||
# Log the actual parameters being used for the LLM call
|
# Log the actual parameters being used for the LLM call
|
||||||
safe_params = params.copy()
|
safe_params = params.copy()
|
||||||
|
|
Loading…
Reference in New Issue