From d4beb73a7ae62ff3e804075a7e6f27bb5c9a1826 Mon Sep 17 00:00:00 2001 From: Steve White Date: Fri, 28 Feb 2025 15:25:33 -0600 Subject: [PATCH] Update Gemini integration to use correct LiteLLM format --- config/config.py | 50 ++++++++++++++++++++------------------ config/config.yaml | 7 +++--- query/llm_interface.py | 19 ++++++++++++--- report/report_synthesis.py | 19 ++++++++++++--- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/config/config.py b/config/config.py index f7174db..085a41b 100644 --- a/config/config.py +++ b/config/config.py @@ -47,45 +47,47 @@ class Config: print(f"Error loading configuration: {e}") def get_api_key(self, provider: str) -> str: - """ - Get API key for the specified provider. + """Get the API key for a specific provider. Args: - provider: The name of the API provider (e.g., 'openai', 'jina', 'serper') + provider: The provider name (e.g., 'openai', 'anthropic', 'google') Returns: - The API key as a string + The API key for the specified provider Raises: ValueError: If the API key is not found """ - # First check environment variables (higher priority) - env_var_name = f"{provider.upper()}_API_KEY" + provider = provider.lower() - # Special case for Jina AI which uses JINA_API_KEY - if provider.lower() == 'jina': - env_var_name = "JINA_API_KEY" + # Map provider names to environment variable names + provider_env_map = { + 'openai': 'OPENAI_API_KEY', + 'anthropic': 'ANTHROPIC_API_KEY', + 'google': 'GEMINI_API_KEY', + 'gemini': 'GEMINI_API_KEY', + 'vertex_ai': 'GOOGLE_APPLICATION_CREDENTIALS', + 'groq': 'GROQ_API_KEY', + 'openrouter': 'OPENROUTER_API_KEY', + 'serper': 'SERPER_API_KEY', + 'tavily': 'TAVILY_API_KEY', + 'perplexity': 'PERPLEXITY_API_KEY' + } - # Special case for Groq which might use GROQ_API_KEY - if provider.lower() == 'groq': - env_var_name = "GROQ_API_KEY" + # Get the environment variable name for the provider + env_var = provider_env_map.get(provider) + if not env_var: + env_var = f"{provider.upper()}_API_KEY" - # Special case for OpenRouter which might use OPENROUTER_API_KEY - if provider.lower() == 'openrouter': - env_var_name = "OPENROUTER_API_KEY" + # Try to get the API key from environment variables + api_key = os.environ.get(env_var) - # Special case for Google which might use GEMINI_API_KEY - if provider.lower() == 'google': - env_var_name = "GEMINI_API_KEY" - - api_key = os.environ.get(env_var_name) - - # If not in environment, check config file - if not api_key and self.config_data and 'api_keys' in self.config_data: + # If not found in environment, check the config file + if not api_key and 'api_keys' in self.config_data: api_key = self.config_data['api_keys'].get(provider) if not api_key: - raise ValueError(f"API key for {provider} not found. Set {env_var_name} environment variable or add to config file.") + raise ValueError(f"API key for {provider} not found. Please set the {env_var} environment variable or add it to the config file.") return api_key diff --git a/config/config.yaml b/config/config.yaml index da32dc1..d988fab 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -82,13 +82,12 @@ models: top_p: 1.0 endpoint: "https://openrouter.ai/api/v1" - gemini-2.0-flash-lite: - provider: "google" - model_name: "gemini-2.0-flash-lite" + gemini-2.0-flash: + provider: "gemini" + model_name: "gemini-2.0-flash" temperature: 0.5 max_tokens: 2048 top_p: 1.0 - endpoint: "https://generativelanguage.googleapis.com/v1beta" # Default model to use if not specified for a module default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing diff --git a/query/llm_interface.py b/query/llm_interface.py index e997d6e..7d818b8 100644 --- a/query/llm_interface.py +++ b/query/llm_interface.py @@ -46,8 +46,10 @@ class LLMInterface: api_key = self.config.get_api_key(provider) # Set environment variable for the provider - if provider.lower() == 'google': + if provider.lower() == 'google' or provider.lower() == 'gemini': os.environ["GEMINI_API_KEY"] = api_key + elif provider.lower() == 'vertex_ai': + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key else: os.environ[f"{provider.upper()}_API_KEY"] = api_key @@ -102,14 +104,23 @@ class LLMInterface: 'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL 'X-Title': 'Intelligent Research System' # Replace with your actual app name } - elif provider == 'google': + elif provider == 'google' or provider == 'gemini': # Special handling for Google Gemini models + # Format: gemini/model_name (e.g., gemini/gemini-2.0-flash) params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}" - # Google Gemini uses a different API base - params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta') # Add additional parameters for Gemini params['custom_llm_provider'] = 'gemini' + elif provider == 'vertex_ai': + # Special handling for Vertex AI Gemini models + params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}" + + # Add Vertex AI specific parameters + params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search') + params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1') + + # Set custom provider + params['custom_llm_provider'] = 'vertex_ai' else: # Standard provider (OpenAI, Anthropic, etc.) params['model'] = self.model_name diff --git a/report/report_synthesis.py b/report/report_synthesis.py index 5b32aa1..5a11ebc 100644 --- a/report/report_synthesis.py +++ b/report/report_synthesis.py @@ -60,8 +60,10 @@ class ReportSynthesizer: api_key = self.config.get_api_key(provider) # Set environment variable for the provider - if provider.lower() == 'google': + if provider.lower() == 'google' or provider.lower() == 'gemini': os.environ["GEMINI_API_KEY"] = api_key + elif provider.lower() == 'vertex_ai': + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key else: os.environ[f"{provider.upper()}_API_KEY"] = api_key @@ -98,14 +100,23 @@ class ReportSynthesizer: 'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL 'X-Title': 'Intelligent Research System' # Replace with your actual app name } - elif provider == 'google': + elif provider == 'google' or provider == 'gemini': # Special handling for Google Gemini models + # Format: gemini/model_name (e.g., gemini/gemini-2.0-flash) params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}" - # Google Gemini uses a different API base - params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta') # Add additional parameters for Gemini params['custom_llm_provider'] = 'gemini' + elif provider == 'vertex_ai': + # Special handling for Vertex AI Gemini models + params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}" + + # Add Vertex AI specific parameters + params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search') + params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1') + + # Set custom provider + params['custom_llm_provider'] = 'vertex_ai' else: # Standard provider (OpenAI, Anthropic, etc.) params['model'] = self.model_name