Add Google Gemini model support and dynamic model selection in UI

Steve White 2025-02-28 11:31:53 -06:00
parent 76cedb9528
commit f6f660da97
6 changed files with 283 additions and 44 deletions

View File

@@ -74,6 +74,10 @@ class Config:
        if provider.lower() == 'openrouter':
            env_var_name = "OPENROUTER_API_KEY"
        # Special case for Google which might use GOOGLE_API_KEY
        if provider.lower() == 'google':
            env_var_name = "GOOGLE_API_KEY"
        api_key = os.environ.get(env_var_name)
        # If not in environment, check config file
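A minimal sketch of the resulting lookup order, assuming the default env-var name is derived from the provider and that the config-file fallback follows; the helper below is illustrative, not the repository's actual method:

```python
import os

def resolve_api_key(provider: str, config_data: dict) -> str | None:
    """Hypothetical sketch: provider-specific env var first, then the
    api_keys section of config.yaml."""
    env_var_name = f"{provider.upper()}_API_KEY"  # assumed default naming
    if provider.lower() == 'openrouter':
        env_var_name = "OPENROUTER_API_KEY"
    if provider.lower() == 'google':
        env_var_name = "GOOGLE_API_KEY"
    # Environment variable wins; otherwise fall back to the config file
    return os.environ.get(env_var_name) or config_data.get('api_keys', {}).get(provider.lower())

# resolve_api_key("google", cfg) checks GOOGLE_API_KEY, then api_keys.google in config.yaml
```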

config/config.yaml (new file, 158 lines)
View File

@@ -0,0 +1,158 @@
# Example configuration file for the intelligent research system
# Rename this file to config.yaml and fill in your API keys and settings

# API keys (alternatively, set environment variables)
api_keys:
  openai: "your-openai-api-key"  # Or set OPENAI_API_KEY environment variable
  jina: "your-jina-api-key"  # Or set JINA_API_KEY environment variable
  serper: "your-serper-api-key"  # Or set SERPER_API_KEY environment variable
  google: "your-google-api-key"  # Or set GOOGLE_API_KEY environment variable
  anthropic: "your-anthropic-api-key"  # Or set ANTHROPIC_API_KEY environment variable
  openrouter: "your-openrouter-api-key"  # Or set OPENROUTER_API_KEY environment variable
  groq: "your-groq-api-key"  # Or set GROQ_API_KEY environment variable

# LLM model configurations
models:
  gpt-3.5-turbo:
    provider: "openai"
    temperature: 0.7
    max_tokens: 1000
    top_p: 1.0
    endpoint: null  # Use default OpenAI endpoint

  gpt-4:
    provider: "openai"
    temperature: 0.5
    max_tokens: 2000
    top_p: 1.0
    endpoint: null  # Use default OpenAI endpoint

  claude-2:
    provider: "anthropic"
    temperature: 0.7
    max_tokens: 1500
    top_p: 1.0
    endpoint: null  # Use default Anthropic endpoint

  azure-gpt-4:
    provider: "azure"
    temperature: 0.5
    max_tokens: 2000
    top_p: 1.0
    endpoint: "https://your-azure-endpoint.openai.azure.com"
    deployment_name: "your-deployment-name"
    api_version: "2023-05-15"

  local-llama:
    provider: "ollama"
    temperature: 0.8
    max_tokens: 1000
    endpoint: "http://localhost:11434/api/generate"
    model_name: "llama2"

  llama-3.1-8b-instant:
    provider: "groq"
    model_name: "llama-3.1-8b-instant"
    temperature: 0.7
    max_tokens: 1024
    top_p: 1.0
    endpoint: "https://api.groq.com/openai/v1"

  llama-3.3-70b-versatile:
    provider: "groq"
    model_name: "llama-3.3-70b-versatile"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://api.groq.com/openai/v1"

  openrouter-mixtral:
    provider: "openrouter"
    model_name: "mistralai/mixtral-8x7b-instruct"
    temperature: 0.7
    max_tokens: 1024
    top_p: 1.0
    endpoint: "https://openrouter.ai/api/v1"

  openrouter-claude:
    provider: "openrouter"
    model_name: "anthropic/claude-3-opus"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://openrouter.ai/api/v1"

  gemini-2.0-flash-lite:
    provider: "google"
    model_name: "google/gemini-2.0-flash-lite-001"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://generativelanguage.googleapis.com/v1"

# Default model to use if not specified for a module
default_model: "llama-3.1-8b-instant"  # Using Groq's Llama 3.1 8B model for testing

# Module-specific model assignments
module_models:
  # Query processing module
  query_processing:
    enhance_query: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for query enhancement
    classify_query: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for classification
    generate_search_queries: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for generating search queries

  # Search strategy module
  search_strategy:
    develop_strategy: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for developing search strategies
    target_selection: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for target selection

  # Document ranking module
  document_ranking:
    rerank_documents: "jina-reranker"  # Use Jina's reranker for document reranking

  # Report generation module
  report_generation:
    synthesize_report: "llama-3.3-70b-versatile"  # Use Groq's Llama 3.3 70B for report synthesis
    format_report: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for formatting

# Search engine configurations
search_engines:
  google:
    enabled: true
    max_results: 10
  serper:
    enabled: true
    max_results: 10
  jina:
    enabled: true
    max_results: 10
  scholar:
    enabled: false
    max_results: 5
  arxiv:
    enabled: false
    max_results: 5

# Jina AI specific configurations
jina:
  reranker:
    model: "jina-reranker-v2-base-multilingual"  # Default reranker model
    top_n: 10  # Default number of top results to return

# UI configuration
ui:
  theme: "light"  # light or dark
  port: 7860
  share: false
  title: "Intelligent Research System"
  description: "An automated system for finding, filtering, and synthesizing information"

# System settings
system:
  cache_dir: "data/cache"
  results_dir: "data/results"
  log_level: "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
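As a rough sketch of how these sections fit together, a module-level model lookup with fallback to default_model might look like this; the helper name and access pattern are assumptions for illustration, not the project's actual API:

```python
import yaml

def model_for(config_path: str, module: str, function: str) -> str:
    """Hypothetical helper: return the model assigned to module.function,
    falling back to default_model when no assignment exists."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    module_models = cfg.get('module_models', {})
    return module_models.get(module, {}).get(function, cfg.get('default_model'))

# model_for("config/config.yaml", "report_generation", "synthesize_report")
# -> "llama-3.3-70b-versatile"
```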

View File

@@ -99,6 +99,11 @@ class LLMInterface:
                'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                'X-Title': 'Intelligent Research System'  # Replace with your actual app name
            }
        elif provider == 'google':
            # Special handling for Google Gemini models
            params['model'] = self.model_config.get('model_name', self.model_name)
            # Google Gemini uses a different API base
            params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1')
        else:
            # Standard provider (OpenAI, Anthropic, etc.)
            params['model'] = self.model_name
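For the gemini-2.0-flash-lite entry defined in config.yaml above, this branch would assemble roughly the following parameters; which other fields the surrounding code copies from model_config, and the downstream client call, are assumptions here:

```python
# Illustrative parameter set for the 'google' branch (not captured program output)
params = {
    'model': 'google/gemini-2.0-flash-lite-001',                 # model_name from the config entry
    'api_base': 'https://generativelanguage.googleapis.com/v1',  # endpoint from the config entry
    'temperature': 0.5,                                          # assumed to be merged in elsewhere
    'max_tokens': 2048,
}
```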

View File

@@ -56,11 +56,11 @@ class ReportDetailLevelManager:
                "description": "A comprehensive report with in-depth analysis, methodology, and implications."
            },
            DetailLevel.COMPREHENSIVE: {
-               "num_results": 8,
-               "token_budget": 80000,
-               "chunk_size": 800,
-               "overlap_size": 80,
-               "model": "llama-3.3-70b-versatile",
+               "num_results": 12,
+               "token_budget": 200000,
+               "chunk_size": 1200,
+               "overlap_size": 120,
+               "model": "gemini-2.0-flash-lite",
                "description": "An exhaustive report with all available information, extensive analysis, and detailed references."
            }
        }
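For context, a rough sketch of how chunk_size and overlap_size might drive a sliding-window chunker at this detail level; whether the project splits on characters or tokens is not shown in this hunk, so the helper below is purely illustrative:

```python
def chunk_text(text: str, chunk_size: int = 1200, overlap_size: int = 120) -> list[str]:
    """Hypothetical sliding-window chunker using the COMPREHENSIVE settings above."""
    chunks, step = [], chunk_size - overlap_size
    for start in range(0, max(len(text), 1), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break
    return chunks
```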

View File

@@ -180,9 +180,23 @@ class ReportSynthesizer:
        total_chunks = len(chunks)
        logger.info(f"Starting to process {total_chunks} document chunks")

-       for i, chunk in enumerate(chunks, 1):
+       # Determine batch size based on the model - Gemini can handle larger batches
+       if "gemini" in self.model_name.lower():
+           batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+       else:
+           batch_size = 3  # Smaller batch size for other models
+
+       logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
+
+       for i in range(0, len(chunks), batch_size):
+           batch = chunks[i:i+batch_size]
+           logger.info(f"Processing batch {i//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size} with {len(batch)} chunks")
+
+           # Process this batch
+           batch_results = []
+           for chunk in batch:
                chunk_title = chunk.get('title', 'Untitled')
-               logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_title[:50]}...")
+               logger.info(f"Processing chunk {i+1}/{total_chunks}: {chunk_title[:50]}...")

                # Create a prompt for extracting key information from the chunk
                messages = [
@@ -205,15 +219,22 @@ class ReportSynthesizer:
                    # Add the extracted information to the chunk
                    processed_chunk = chunk.copy()
                    processed_chunk['extracted_info'] = extracted_info
-                   processed_chunks.append(processed_chunk)
+                   batch_results.append(processed_chunk)

-                   logger.info(f"Completed chunk {i}/{total_chunks} ({(i/total_chunks)*100:.1f}% complete)")
+                   logger.info(f"Completed chunk {i+1}/{total_chunks} ({(i+1)/total_chunks*100:.1f}% complete)")
                except Exception as e:
-                   logger.error(f"Error processing chunk {i}/{total_chunks}: {str(e)}")
+                   logger.error(f"Error processing chunk {i+1}/{total_chunks}: {str(e)}")
                    # Add a placeholder for the failed chunk to maintain document order
                    processed_chunk = chunk.copy()
                    processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
-                   processed_chunks.append(processed_chunk)
+                   batch_results.append(processed_chunk)
+
+           processed_chunks.extend(batch_results)
+
+           # Add a small delay between batches to avoid rate limiting
+           if i + batch_size < len(chunks):
+               logger.info("Pausing briefly between batches...")
+               await asyncio.sleep(2)

        logger.info(f"Completed processing all {total_chunks} chunks")
        return processed_chunks
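Taken together, the two hunks above replace the per-chunk loop with a batched loop. Below is a condensed, self-contained sketch of the resulting flow; the `extract` callable stands in for the elided prompt-building and completion code and is not a name from the repository:

```python
import asyncio
from typing import Awaitable, Callable

async def process_in_batches(
    chunks: list[dict],
    model_name: str,
    extract: Callable[[dict], Awaitable[str]],  # placeholder for the elided completion step
) -> list[dict]:
    """Condensed illustration of the batched flow introduced in this commit."""
    batch_size = 8 if "gemini" in model_name.lower() else 3
    processed_chunks: list[dict] = []
    for i in range(0, len(chunks), batch_size):
        batch_results = []
        for chunk in chunks[i:i + batch_size]:
            processed = chunk.copy()
            try:
                processed['extracted_info'] = await extract(chunk)
            except Exception as e:
                processed['extracted_info'] = f"Error extracting information: {e}"
            batch_results.append(processed)
        processed_chunks.extend(batch_results)
        if i + batch_size < len(chunks):
            await asyncio.sleep(2)  # brief pause between batches to avoid rate limiting
    return processed_chunks
```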
@@ -355,7 +376,13 @@ class ReportSynthesizer:
        logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")

        # Process chunks in batches to avoid hitting payload limits
-       batch_size = 3  # Process 3 chunks at a time
+       # Determine batch size based on the model - Gemini can handle larger batches
+       if "gemini" in self.model_name.lower():
+           batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+       else:
+           batch_size = 3  # Smaller batch size for other models
+
+       logger.info(f"Using batch size of {batch_size} for model {self.model_name}")

        processed_chunks = []
        for i in range(0, len(chunks), batch_size):

View File

@@ -202,7 +202,13 @@ class GradioInterface:
        # Create a timestamped output file
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_suffix = ""

        # Extract the actual model name from the description if selected
        if custom_model:
            # If the model is in the format "model_name (provider: model_display)"
            if "(" in custom_model:
                custom_model = custom_model.split(" (")[0]
            model_name = custom_model.split('/')[-1]
            model_suffix = f"_{model_name}"
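A quick trace of what this yields for the new display-name format used by the dropdown, using values from the config above:

```python
custom_model = "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)"
if "(" in custom_model:
    custom_model = custom_model.split(" (")[0]  # "gemini-2.0-flash-lite"
model_name = custom_model.split('/')[-1]        # "gemini-2.0-flash-lite"
model_suffix = f"_{model_name}"                 # "_gemini-2.0-flash-lite"
```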
@@ -315,15 +321,53 @@
            list: List of available model names
        """
        # Get models from config
        models = []

        # Extract all model names from the config file
        if 'models' in self.config.config_data:
            models = list(self.config.config_data['models'].keys())

        # If no models found, provide some defaults
        if not models:
            models = [
                "llama-3.1-8b-instant",
                "llama-3.3-70b-versatile",
                "groq/deepseek-r1-distill-llama-70b-specdec",
                "openrouter-mixtral",
-               "openrouter-claude"
+               "openrouter-claude",
+               "gemini-2.0-flash-lite"
            ]

        return models
    def get_model_descriptions(self):
        """
        Get descriptions for available models.

        Returns:
            dict: Dictionary mapping model names to descriptions
        """
        descriptions = {}
        model_name_to_description = {}

        if 'models' in self.config.config_data:
            for model_name, model_config in self.config.config_data['models'].items():
                provider = model_config.get('provider', 'unknown')
                model_display = model_config.get('model_name', model_name)
                max_tokens = model_config.get('max_tokens', 'unknown')
                temperature = model_config.get('temperature', 'unknown')

                # Create a description that includes the provider and actual model name
                display_name = f"{model_name} ({provider}: {model_display})"
                descriptions[model_name] = display_name

                # Create a more detailed description for the dropdown tooltip
                detailed_info = f"{display_name} - Max tokens: {max_tokens}, Temperature: {temperature}"
                model_name_to_description[display_name] = detailed_info

        self.model_name_to_description = model_name_to_description
        return descriptions
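With the gemini-2.0-flash-lite entry from config.yaml, the two mappings come out roughly as follows; this is an illustration, not captured output, and it assumes an already constructed interface object:

```python
# Assuming `ui` is a constructed GradioInterface loaded with the config shown above:
descriptions = ui.get_model_descriptions()
descriptions["gemini-2.0-flash-lite"]
# -> "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)"
ui.model_name_to_description[descriptions["gemini-2.0-flash-lite"]]
# -> "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001) - Max tokens: 2048, Temperature: 0.5"
```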
    def create_interface(self):
        """
        Create and return the Gradio interface.
@@ -401,8 +445,9 @@ class GradioInterface:
                        label="Detail Level",
                        info="Controls the depth and breadth of the report"
                    )
+                   model_descriptions = self.get_model_descriptions()
                    report_custom_model = gr.Dropdown(
-                       choices=self.get_available_models(),
+                       choices=list(self.model_name_to_description.keys()),
                        value=None,
                        label="Custom Model (Optional)",
                        info="Select a custom model for report generation"