Update Gemini integration to use correct LiteLLM format

This commit is contained in:
Steve White 2025-02-28 15:25:33 -06:00
parent f94bde875b
commit d4beb73a7a
4 changed files with 59 additions and 36 deletions

View File

@ -47,45 +47,47 @@ class Config:
print(f"Error loading configuration: {e}") print(f"Error loading configuration: {e}")
def get_api_key(self, provider: str) -> str: def get_api_key(self, provider: str) -> str:
""" """Get the API key for a specific provider.
Get API key for the specified provider.
Args: Args:
provider: The name of the API provider (e.g., 'openai', 'jina', 'serper') provider: The provider name (e.g., 'openai', 'anthropic', 'google')
Returns: Returns:
The API key as a string The API key for the specified provider
Raises: Raises:
ValueError: If the API key is not found ValueError: If the API key is not found
""" """
# First check environment variables (higher priority) provider = provider.lower()
env_var_name = f"{provider.upper()}_API_KEY"
# Special case for Jina AI which uses JINA_API_KEY # Map provider names to environment variable names
if provider.lower() == 'jina': provider_env_map = {
env_var_name = "JINA_API_KEY" 'openai': 'OPENAI_API_KEY',
'anthropic': 'ANTHROPIC_API_KEY',
'google': 'GEMINI_API_KEY',
'gemini': 'GEMINI_API_KEY',
'vertex_ai': 'GOOGLE_APPLICATION_CREDENTIALS',
'groq': 'GROQ_API_KEY',
'openrouter': 'OPENROUTER_API_KEY',
'serper': 'SERPER_API_KEY',
'tavily': 'TAVILY_API_KEY',
'perplexity': 'PERPLEXITY_API_KEY'
}
# Special case for Groq which might use GROQ_API_KEY # Get the environment variable name for the provider
if provider.lower() == 'groq': env_var = provider_env_map.get(provider)
env_var_name = "GROQ_API_KEY" if not env_var:
env_var = f"{provider.upper()}_API_KEY"
# Special case for OpenRouter which might use OPENROUTER_API_KEY # Try to get the API key from environment variables
if provider.lower() == 'openrouter': api_key = os.environ.get(env_var)
env_var_name = "OPENROUTER_API_KEY"
# Special case for Google which might use GEMINI_API_KEY # If not found in environment, check the config file
if provider.lower() == 'google': if not api_key and 'api_keys' in self.config_data:
env_var_name = "GEMINI_API_KEY"
api_key = os.environ.get(env_var_name)
# If not in environment, check config file
if not api_key and self.config_data and 'api_keys' in self.config_data:
api_key = self.config_data['api_keys'].get(provider) api_key = self.config_data['api_keys'].get(provider)
if not api_key: if not api_key:
raise ValueError(f"API key for {provider} not found. Set {env_var_name} environment variable or add to config file.") raise ValueError(f"API key for {provider} not found. Please set the {env_var} environment variable or add it to the config file.")
return api_key return api_key

View File

@ -82,13 +82,12 @@ models:
top_p: 1.0 top_p: 1.0
endpoint: "https://openrouter.ai/api/v1" endpoint: "https://openrouter.ai/api/v1"
gemini-2.0-flash-lite: gemini-2.0-flash:
provider: "google" provider: "gemini"
model_name: "gemini-2.0-flash-lite" model_name: "gemini-2.0-flash"
temperature: 0.5 temperature: 0.5
max_tokens: 2048 max_tokens: 2048
top_p: 1.0 top_p: 1.0
endpoint: "https://generativelanguage.googleapis.com/v1beta"
# Default model to use if not specified for a module # Default model to use if not specified for a module
default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing

View File

@ -46,8 +46,10 @@ class LLMInterface:
api_key = self.config.get_api_key(provider) api_key = self.config.get_api_key(provider)
# Set environment variable for the provider # Set environment variable for the provider
if provider.lower() == 'google': if provider.lower() == 'google' or provider.lower() == 'gemini':
os.environ["GEMINI_API_KEY"] = api_key os.environ["GEMINI_API_KEY"] = api_key
elif provider.lower() == 'vertex_ai':
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
else: else:
os.environ[f"{provider.upper()}_API_KEY"] = api_key os.environ[f"{provider.upper()}_API_KEY"] = api_key
@ -102,14 +104,23 @@ class LLMInterface:
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL 'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
'X-Title': 'Intelligent Research System' # Replace with your actual app name 'X-Title': 'Intelligent Research System' # Replace with your actual app name
} }
elif provider == 'google': elif provider == 'google' or provider == 'gemini':
# Special handling for Google Gemini models # Special handling for Google Gemini models
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}" params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
# Google Gemini uses a different API base
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
# Add additional parameters for Gemini # Add additional parameters for Gemini
params['custom_llm_provider'] = 'gemini' params['custom_llm_provider'] = 'gemini'
elif provider == 'vertex_ai':
# Special handling for Vertex AI Gemini models
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
# Add Vertex AI specific parameters
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
# Set custom provider
params['custom_llm_provider'] = 'vertex_ai'
else: else:
# Standard provider (OpenAI, Anthropic, etc.) # Standard provider (OpenAI, Anthropic, etc.)
params['model'] = self.model_name params['model'] = self.model_name

View File

@ -60,8 +60,10 @@ class ReportSynthesizer:
api_key = self.config.get_api_key(provider) api_key = self.config.get_api_key(provider)
# Set environment variable for the provider # Set environment variable for the provider
if provider.lower() == 'google': if provider.lower() == 'google' or provider.lower() == 'gemini':
os.environ["GEMINI_API_KEY"] = api_key os.environ["GEMINI_API_KEY"] = api_key
elif provider.lower() == 'vertex_ai':
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
else: else:
os.environ[f"{provider.upper()}_API_KEY"] = api_key os.environ[f"{provider.upper()}_API_KEY"] = api_key
@ -98,14 +100,23 @@ class ReportSynthesizer:
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL 'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
'X-Title': 'Intelligent Research System' # Replace with your actual app name 'X-Title': 'Intelligent Research System' # Replace with your actual app name
} }
elif provider == 'google': elif provider == 'google' or provider == 'gemini':
# Special handling for Google Gemini models # Special handling for Google Gemini models
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}" params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
# Google Gemini uses a different API base
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
# Add additional parameters for Gemini # Add additional parameters for Gemini
params['custom_llm_provider'] = 'gemini' params['custom_llm_provider'] = 'gemini'
elif provider == 'vertex_ai':
# Special handling for Vertex AI Gemini models
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
# Add Vertex AI specific parameters
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
# Set custom provider
params['custom_llm_provider'] = 'vertex_ai'
else: else:
# Standard provider (OpenAI, Anthropic, etc.) # Standard provider (OpenAI, Anthropic, etc.)
params['model'] = self.model_name params['model'] = self.model_name