diff --git a/report/report_synthesis.py b/report/report_synthesis.py
index 513e9e0..c57e396 100644
--- a/report/report_synthesis.py
+++ b/report/report_synthesis.py
@@ -96,7 +96,25 @@ class ReportSynthesizer:
 
     def _setup_provider(self) -> None:
         """Set up the LLM provider based on the model configuration."""
-        provider = self.model_config.get('provider', 'groq')
+        # Determine provider based on model name if not explicitly specified
+        if 'provider' in self.model_config:
+            provider = self.model_config['provider']
+        elif 'gemini' in self.model_name.lower():
+            provider = 'gemini'
+            # Update model_config to reflect the correct provider
+            self.model_config['provider'] = 'gemini'
+        elif 'claude' in self.model_name.lower():
+            provider = 'anthropic'
+            self.model_config['provider'] = 'anthropic'
+        elif 'gpt' in self.model_name.lower() or 'text-embedding' in self.model_name.lower():
+            provider = 'openai'
+            self.model_config['provider'] = 'openai'
+        elif 'mistral' in self.model_name.lower():
+            provider = 'mistral'
+            self.model_config['provider'] = 'mistral'
+        else:
+            provider = 'groq' # Default to groq for other models
+            self.model_config['provider'] = 'groq'
 
         # Log detailed model information for debugging
         logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
@@ -207,9 +225,12 @@ class ReportSynthesizer:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
-        # Get provider from model config
+        # Get provider from model config - this should be set correctly from the config file
         provider = self.model_config.get('provider', 'groq').lower()
 
+        # Log the provider and model being used for debugging
+        logger.info(f"Using provider: {provider} with model: {self.model_name}")
+
         # Special handling for Gemini models - they use 'user' and 'model' roles
         if provider == 'gemini':
             formatted_messages = []
@@ -236,11 +257,29 @@ class ReportSynthesizer:
         # Get completion parameters
         params = self._get_completion_params()
 
-        # Double-check that we're using the correct model
+        # Use the provider from the model config - this should be set correctly from the config file
+        # Different providers have different formatting requirements for the model parameter
         if provider == 'gemini':
-            params['model'] = f"gemini/{self.model_name}"
+            # For Gemini models, we use the vertex_ai provider in litellm
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'vertex_ai'
        elif provider == 'groq':
-            params['model'] = f"groq/{self.model_name}"
+            # For Groq models, we need to prefix with provider
+            if not self.model_name.startswith('groq/'):
+                params['model'] = f"groq/{self.model_name}"
+            else:
+                params['model'] = self.model_name
+        elif provider == 'anthropic':
+            # For Claude models
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'anthropic'
+        elif provider == 'openai':
+            # For OpenAI models
+            params['model'] = self.model_name
+            params['custom_llm_provider'] = 'openai'
+
+        # Log the final model parameter being used
+        logger.info(f"Final model parameter: {params.get('model', 'unknown')} with provider: {provider}")
 
         # Log the actual parameters being used for the LLM call
         safe_params = params.copy()
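
For illustration, here is a minimal standalone sketch of the provider-resolution rule the first hunk introduces. The helper name and the example model strings are hypothetical; the mapping itself mirrors the diff, which falls back to groq when no pattern matches and routes Gemini models through litellm's vertex_ai provider.

```python
# Hypothetical standalone restatement of the name-based provider inference above.
def resolve_provider(model_name: str, model_config: dict) -> str:
    """Return the provider from the config if given, else infer it from the model name."""
    if 'provider' in model_config:
        return model_config['provider']
    name = model_name.lower()
    if 'gemini' in name:
        return 'gemini'
    if 'claude' in name:
        return 'anthropic'
    if 'gpt' in name or 'text-embedding' in name:
        return 'openai'
    if 'mistral' in name:
        return 'mistral'
    return 'groq'  # default, as in the diff

# Example model names are illustrative only.
assert resolve_provider('gemini-2.0-flash', {}) == 'gemini'
assert resolve_provider('claude-3-5-sonnet', {}) == 'anthropic'
assert resolve_provider('llama-3.3-70b-versatile', {}) == 'groq'
assert resolve_provider('anything', {'provider': 'openai'}) == 'openai'
```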