Update Gemini integration to use correct LiteLLM format

This commit is contained in:
Steve White 2025-02-28 15:25:33 -06:00
parent f94bde875b
commit d4beb73a7a
4 changed files with 59 additions and 36 deletions

View File

@@ -47,45 +47,47 @@ class Config:
print(f"Error loading configuration: {e}")
def get_api_key(self, provider: str) -> str:
"""
Get API key for the specified provider.
"""Get the API key for a specific provider.
Args:
provider: The name of the API provider (e.g., 'openai', 'jina', 'serper')
provider: The provider name (e.g., 'openai', 'anthropic', 'google')
Returns:
The API key as a string
The API key for the specified provider
Raises:
ValueError: If the API key is not found
"""
# First check environment variables (higher priority)
env_var_name = f"{provider.upper()}_API_KEY"
provider = provider.lower()
# Special case for Jina AI which uses JINA_API_KEY
if provider.lower() == 'jina':
env_var_name = "JINA_API_KEY"
# Map provider names to environment variable names
provider_env_map = {
'openai': 'OPENAI_API_KEY',
'anthropic': 'ANTHROPIC_API_KEY',
'google': 'GEMINI_API_KEY',
'gemini': 'GEMINI_API_KEY',
'vertex_ai': 'GOOGLE_APPLICATION_CREDENTIALS',
'groq': 'GROQ_API_KEY',
'openrouter': 'OPENROUTER_API_KEY',
'serper': 'SERPER_API_KEY',
'tavily': 'TAVILY_API_KEY',
'perplexity': 'PERPLEXITY_API_KEY'
}
# Special case for Groq which might use GROQ_API_KEY
if provider.lower() == 'groq':
env_var_name = "GROQ_API_KEY"
# Get the environment variable name for the provider
env_var = provider_env_map.get(provider)
if not env_var:
env_var = f"{provider.upper()}_API_KEY"
# Special case for OpenRouter which might use OPENROUTER_API_KEY
if provider.lower() == 'openrouter':
env_var_name = "OPENROUTER_API_KEY"
# Try to get the API key from environment variables
api_key = os.environ.get(env_var)
# Special case for Google which might use GEMINI_API_KEY
if provider.lower() == 'google':
env_var_name = "GEMINI_API_KEY"
api_key = os.environ.get(env_var_name)
# If not in environment, check config file
if not api_key and self.config_data and 'api_keys' in self.config_data:
# If not found in environment, check the config file
if not api_key and 'api_keys' in self.config_data:
api_key = self.config_data['api_keys'].get(provider)
if not api_key:
raise ValueError(f"API key for {provider} not found. Set {env_var_name} environment variable or add to config file.")
raise ValueError(f"API key for {provider} not found. Please set the {env_var} environment variable or add it to the config file.")
return api_key

View File

@@ -82,13 +82,12 @@ models:
top_p: 1.0
endpoint: "https://openrouter.ai/api/v1"
gemini-2.0-flash-lite:
provider: "google"
model_name: "gemini-2.0-flash-lite"
gemini-2.0-flash:
provider: "gemini"
model_name: "gemini-2.0-flash"
temperature: 0.5
max_tokens: 2048
top_p: 1.0
endpoint: "https://generativelanguage.googleapis.com/v1beta"
# Default model to use if not specified for a module
default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing

View File

@@ -46,8 +46,10 @@ class LLMInterface:
api_key = self.config.get_api_key(provider)
# Set environment variable for the provider
if provider.lower() == 'google':
if provider.lower() == 'google' or provider.lower() == 'gemini':
os.environ["GEMINI_API_KEY"] = api_key
elif provider.lower() == 'vertex_ai':
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
else:
os.environ[f"{provider.upper()}_API_KEY"] = api_key
@@ -102,14 +104,23 @@ class LLMInterface:
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
'X-Title': 'Intelligent Research System' # Replace with your actual app name
}
elif provider == 'google':
elif provider == 'google' or provider == 'gemini':
# Special handling for Google Gemini models
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
# Google Gemini uses a different API base
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
# Add additional parameters for Gemini
params['custom_llm_provider'] = 'gemini'
elif provider == 'vertex_ai':
# Special handling for Vertex AI Gemini models
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
# Add Vertex AI specific parameters
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
# Set custom provider
params['custom_llm_provider'] = 'vertex_ai'
else:
# Standard provider (OpenAI, Anthropic, etc.)
params['model'] = self.model_name

View File

@@ -60,8 +60,10 @@ class ReportSynthesizer:
api_key = self.config.get_api_key(provider)
# Set environment variable for the provider
if provider.lower() == 'google':
if provider.lower() == 'google' or provider.lower() == 'gemini':
os.environ["GEMINI_API_KEY"] = api_key
elif provider.lower() == 'vertex_ai':
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
else:
os.environ[f"{provider.upper()}_API_KEY"] = api_key
@@ -98,14 +100,23 @@ class ReportSynthesizer:
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
'X-Title': 'Intelligent Research System' # Replace with your actual app name
}
elif provider == 'google':
elif provider == 'google' or provider == 'gemini':
# Special handling for Google Gemini models
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
# Google Gemini uses a different API base
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
# Add additional parameters for Gemini
params['custom_llm_provider'] = 'gemini'
elif provider == 'vertex_ai':
# Special handling for Vertex AI Gemini models
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
# Add Vertex AI specific parameters
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
# Set custom provider
params['custom_llm_provider'] = 'vertex_ai'
else:
# Standard provider (OpenAI, Anthropic, etc.)
params['model'] = self.model_name