Update Gemini integration to use correct LiteLLM format
This commit is contained in:
parent
f94bde875b
commit
d4beb73a7a
|
@ -47,45 +47,47 @@ class Config:
|
|||
print(f"Error loading configuration: {e}")
|
||||
|
||||
def get_api_key(self, provider: str) -> str:
|
||||
"""
|
||||
Get API key for the specified provider.
|
||||
"""Get the API key for a specific provider.
|
||||
|
||||
Args:
|
||||
provider: The name of the API provider (e.g., 'openai', 'jina', 'serper')
|
||||
provider: The provider name (e.g., 'openai', 'anthropic', 'google')
|
||||
|
||||
Returns:
|
||||
The API key as a string
|
||||
The API key for the specified provider
|
||||
|
||||
Raises:
|
||||
ValueError: If the API key is not found
|
||||
"""
|
||||
# First check environment variables (higher priority)
|
||||
env_var_name = f"{provider.upper()}_API_KEY"
|
||||
provider = provider.lower()
|
||||
|
||||
# Special case for Jina AI which uses JINA_API_KEY
|
||||
if provider.lower() == 'jina':
|
||||
env_var_name = "JINA_API_KEY"
|
||||
# Map provider names to environment variable names
|
||||
provider_env_map = {
|
||||
'openai': 'OPENAI_API_KEY',
|
||||
'anthropic': 'ANTHROPIC_API_KEY',
|
||||
'google': 'GEMINI_API_KEY',
|
||||
'gemini': 'GEMINI_API_KEY',
|
||||
'vertex_ai': 'GOOGLE_APPLICATION_CREDENTIALS',
|
||||
'groq': 'GROQ_API_KEY',
|
||||
'openrouter': 'OPENROUTER_API_KEY',
|
||||
'serper': 'SERPER_API_KEY',
|
||||
'tavily': 'TAVILY_API_KEY',
|
||||
'perplexity': 'PERPLEXITY_API_KEY'
|
||||
}
|
||||
|
||||
# Special case for Groq which might use GROQ_API_KEY
|
||||
if provider.lower() == 'groq':
|
||||
env_var_name = "GROQ_API_KEY"
|
||||
# Get the environment variable name for the provider
|
||||
env_var = provider_env_map.get(provider)
|
||||
if not env_var:
|
||||
env_var = f"{provider.upper()}_API_KEY"
|
||||
|
||||
# Special case for OpenRouter which might use OPENROUTER_API_KEY
|
||||
if provider.lower() == 'openrouter':
|
||||
env_var_name = "OPENROUTER_API_KEY"
|
||||
# Try to get the API key from environment variables
|
||||
api_key = os.environ.get(env_var)
|
||||
|
||||
# Special case for Google which might use GEMINI_API_KEY
|
||||
if provider.lower() == 'google':
|
||||
env_var_name = "GEMINI_API_KEY"
|
||||
|
||||
api_key = os.environ.get(env_var_name)
|
||||
|
||||
# If not in environment, check config file
|
||||
if not api_key and self.config_data and 'api_keys' in self.config_data:
|
||||
# If not found in environment, check the config file
|
||||
if not api_key and 'api_keys' in self.config_data:
|
||||
api_key = self.config_data['api_keys'].get(provider)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"API key for {provider} not found. Set {env_var_name} environment variable or add to config file.")
|
||||
raise ValueError(f"API key for {provider} not found. Please set the {env_var} environment variable or add it to the config file.")
|
||||
|
||||
return api_key
|
||||
|
||||
|
|
|
@ -82,13 +82,12 @@ models:
|
|||
top_p: 1.0
|
||||
endpoint: "https://openrouter.ai/api/v1"
|
||||
|
||||
gemini-2.0-flash-lite:
|
||||
provider: "google"
|
||||
model_name: "gemini-2.0-flash-lite"
|
||||
gemini-2.0-flash:
|
||||
provider: "gemini"
|
||||
model_name: "gemini-2.0-flash"
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
top_p: 1.0
|
||||
endpoint: "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
# Default model to use if not specified for a module
|
||||
default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing
|
||||
|
|
|
@ -46,8 +46,10 @@ class LLMInterface:
|
|||
api_key = self.config.get_api_key(provider)
|
||||
|
||||
# Set environment variable for the provider
|
||||
if provider.lower() == 'google':
|
||||
if provider.lower() == 'google' or provider.lower() == 'gemini':
|
||||
os.environ["GEMINI_API_KEY"] = api_key
|
||||
elif provider.lower() == 'vertex_ai':
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
|
||||
else:
|
||||
os.environ[f"{provider.upper()}_API_KEY"] = api_key
|
||||
|
||||
|
@ -102,14 +104,23 @@ class LLMInterface:
|
|||
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
|
||||
'X-Title': 'Intelligent Research System' # Replace with your actual app name
|
||||
}
|
||||
elif provider == 'google':
|
||||
elif provider == 'google' or provider == 'gemini':
|
||||
# Special handling for Google Gemini models
|
||||
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
|
||||
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
|
||||
# Google Gemini uses a different API base
|
||||
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
|
||||
|
||||
# Add additional parameters for Gemini
|
||||
params['custom_llm_provider'] = 'gemini'
|
||||
elif provider == 'vertex_ai':
|
||||
# Special handling for Vertex AI Gemini models
|
||||
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
|
||||
|
||||
# Add Vertex AI specific parameters
|
||||
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
|
||||
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
|
||||
|
||||
# Set custom provider
|
||||
params['custom_llm_provider'] = 'vertex_ai'
|
||||
else:
|
||||
# Standard provider (OpenAI, Anthropic, etc.)
|
||||
params['model'] = self.model_name
|
||||
|
|
|
@ -60,8 +60,10 @@ class ReportSynthesizer:
|
|||
api_key = self.config.get_api_key(provider)
|
||||
|
||||
# Set environment variable for the provider
|
||||
if provider.lower() == 'google':
|
||||
if provider.lower() == 'google' or provider.lower() == 'gemini':
|
||||
os.environ["GEMINI_API_KEY"] = api_key
|
||||
elif provider.lower() == 'vertex_ai':
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
|
||||
else:
|
||||
os.environ[f"{provider.upper()}_API_KEY"] = api_key
|
||||
|
||||
|
@ -98,14 +100,23 @@ class ReportSynthesizer:
|
|||
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
|
||||
'X-Title': 'Intelligent Research System' # Replace with your actual app name
|
||||
}
|
||||
elif provider == 'google':
|
||||
elif provider == 'google' or provider == 'gemini':
|
||||
# Special handling for Google Gemini models
|
||||
# Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
|
||||
params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"
|
||||
# Google Gemini uses a different API base
|
||||
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1beta')
|
||||
|
||||
# Add additional parameters for Gemini
|
||||
params['custom_llm_provider'] = 'gemini'
|
||||
elif provider == 'vertex_ai':
|
||||
# Special handling for Vertex AI Gemini models
|
||||
params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"
|
||||
|
||||
# Add Vertex AI specific parameters
|
||||
params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
|
||||
params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
|
||||
|
||||
# Set custom provider
|
||||
params['custom_llm_provider'] = 'vertex_ai'
|
||||
else:
|
||||
# Standard provider (OpenAI, Anthropic, etc.)
|
||||
params['model'] = self.model_name
|
||||
|
|
Loading…
Reference in New Issue