Add Google Gemini model support and dynamic model selection in UI
This commit is contained in:
parent
76cedb9528
commit
f6f660da97
|
@ -74,6 +74,10 @@ class Config:
|
|||
if provider.lower() == 'openrouter':
|
||||
env_var_name = "OPENROUTER_API_KEY"
|
||||
|
||||
# Special case for Google which might use GOOGLE_API_KEY
|
||||
if provider.lower() == 'google':
|
||||
env_var_name = "GOOGLE_API_KEY"
|
||||
|
||||
api_key = os.environ.get(env_var_name)
|
||||
|
||||
# If not in environment, check config file
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
# Example configuration file for the intelligent research system
|
||||
# Rename this file to config.yaml and fill in your API keys and settings
|
||||
|
||||
# API keys (alternatively, set environment variables)
|
||||
api_keys:
|
||||
openai: "your-openai-api-key" # Or set OPENAI_API_KEY environment variable
|
||||
jina: "your-jina-api-key" # Or set JINA_API_KEY environment variable
|
||||
serper: "your-serper-api-key" # Or set SERPER_API_KEY environment variable
|
||||
google: "your-google-api-key" # Or set GOOGLE_API_KEY environment variable
|
||||
anthropic: "your-anthropic-api-key" # Or set ANTHROPIC_API_KEY environment variable
|
||||
openrouter: "your-openrouter-api-key" # Or set OPENROUTER_API_KEY environment variable
|
||||
groq: "your-groq-api-key" # Or set GROQ_API_KEY environment variable
|
||||
|
||||
# LLM model configurations
|
||||
models:
|
||||
gpt-3.5-turbo:
|
||||
provider: "openai"
|
||||
temperature: 0.7
|
||||
max_tokens: 1000
|
||||
top_p: 1.0
|
||||
endpoint: null # Use default OpenAI endpoint
|
||||
|
||||
gpt-4:
|
||||
provider: "openai"
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
top_p: 1.0
|
||||
endpoint: null # Use default OpenAI endpoint
|
||||
|
||||
claude-2:
|
||||
provider: "anthropic"
|
||||
temperature: 0.7
|
||||
max_tokens: 1500
|
||||
top_p: 1.0
|
||||
endpoint: null # Use default Anthropic endpoint
|
||||
|
||||
azure-gpt-4:
|
||||
provider: "azure"
|
||||
temperature: 0.5
|
||||
max_tokens: 2000
|
||||
top_p: 1.0
|
||||
endpoint: "https://your-azure-endpoint.openai.azure.com"
|
||||
deployment_name: "your-deployment-name"
|
||||
api_version: "2023-05-15"
|
||||
|
||||
local-llama:
|
||||
provider: "ollama"
|
||||
temperature: 0.8
|
||||
max_tokens: 1000
|
||||
endpoint: "http://localhost:11434/api/generate"
|
||||
model_name: "llama2"
|
||||
|
||||
llama-3.1-8b-instant:
|
||||
provider: "groq"
|
||||
model_name: "llama-3.1-8b-instant"
|
||||
temperature: 0.7
|
||||
max_tokens: 1024
|
||||
top_p: 1.0
|
||||
endpoint: "https://api.groq.com/openai/v1"
|
||||
|
||||
llama-3.3-70b-versatile:
|
||||
provider: "groq"
|
||||
model_name: "llama-3.3-70b-versatile"
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
top_p: 1.0
|
||||
endpoint: "https://api.groq.com/openai/v1"
|
||||
|
||||
openrouter-mixtral:
|
||||
provider: "openrouter"
|
||||
model_name: "mistralai/mixtral-8x7b-instruct"
|
||||
temperature: 0.7
|
||||
max_tokens: 1024
|
||||
top_p: 1.0
|
||||
endpoint: "https://openrouter.ai/api/v1"
|
||||
|
||||
openrouter-claude:
|
||||
provider: "openrouter"
|
||||
model_name: "anthropic/claude-3-opus"
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
top_p: 1.0
|
||||
endpoint: "https://openrouter.ai/api/v1"
|
||||
|
||||
gemini-2.0-flash-lite:
|
||||
provider: "google"
|
||||
model_name: "google/gemini-2.0-flash-lite-001"
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
top_p: 1.0
|
||||
endpoint: "https://generativelanguage.googleapis.com/v1"
|
||||
|
||||
# Default model to use if not specified for a module
|
||||
default_model: "llama-3.1-8b-instant" # Using Groq's Llama 3.1 8B model for testing
|
||||
|
||||
# Module-specific model assignments
|
||||
module_models:
|
||||
# Query processing module
|
||||
query_processing:
|
||||
enhance_query: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for query enhancement
|
||||
classify_query: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for classification
|
||||
generate_search_queries: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for generating search queries
|
||||
|
||||
# Search strategy module
|
||||
search_strategy:
|
||||
develop_strategy: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for developing search strategies
|
||||
target_selection: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for target selection
|
||||
|
||||
# Document ranking module
|
||||
document_ranking:
|
||||
rerank_documents: "jina-reranker" # Use Jina's reranker for document reranking
|
||||
|
||||
# Report generation module
|
||||
report_generation:
|
||||
synthesize_report: "llama-3.3-70b-versatile" # Use Groq's Llama 3.3 70B for report synthesis
|
||||
format_report: "llama-3.1-8b-instant" # Use Groq's Llama 3.1 8B for formatting
|
||||
|
||||
# Search engine configurations
|
||||
search_engines:
|
||||
google:
|
||||
enabled: true
|
||||
max_results: 10
|
||||
|
||||
serper:
|
||||
enabled: true
|
||||
max_results: 10
|
||||
|
||||
jina:
|
||||
enabled: true
|
||||
max_results: 10
|
||||
|
||||
scholar:
|
||||
enabled: false
|
||||
max_results: 5
|
||||
|
||||
arxiv:
|
||||
enabled: false
|
||||
max_results: 5
|
||||
|
||||
# Jina AI specific configurations
|
||||
jina:
|
||||
reranker:
|
||||
model: "jina-reranker-v2-base-multilingual" # Default reranker model
|
||||
top_n: 10 # Default number of top results to return
|
||||
|
||||
# UI configuration
|
||||
ui:
|
||||
theme: "light" # light or dark
|
||||
port: 7860
|
||||
share: false
|
||||
title: "Intelligent Research System"
|
||||
description: "An automated system for finding, filtering, and synthesizing information"
|
||||
|
||||
# System settings
|
||||
system:
|
||||
cache_dir: "data/cache"
|
||||
results_dir: "data/results"
|
||||
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
|
@ -99,6 +99,11 @@ class LLMInterface:
|
|||
'HTTP-Referer': 'https://sim-search.app', # Replace with your actual app URL
|
||||
'X-Title': 'Intelligent Research System' # Replace with your actual app name
|
||||
}
|
||||
elif provider == 'google':
|
||||
# Special handling for Google Gemini models
|
||||
params['model'] = self.model_config.get('model_name', self.model_name)
|
||||
# Google Gemini uses a different API base
|
||||
params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1')
|
||||
else:
|
||||
# Standard provider (OpenAI, Anthropic, etc.)
|
||||
params['model'] = self.model_name
|
||||
|
|
|
@ -56,11 +56,11 @@ class ReportDetailLevelManager:
|
|||
"description": "A comprehensive report with in-depth analysis, methodology, and implications."
|
||||
},
|
||||
DetailLevel.COMPREHENSIVE: {
|
||||
"num_results": 8,
|
||||
"token_budget": 80000,
|
||||
"chunk_size": 800,
|
||||
"overlap_size": 80,
|
||||
"model": "llama-3.3-70b-versatile",
|
||||
"num_results": 12,
|
||||
"token_budget": 200000,
|
||||
"chunk_size": 1200,
|
||||
"overlap_size": 120,
|
||||
"model": "gemini-2.0-flash-lite",
|
||||
"description": "An exhaustive report with all available information, extensive analysis, and detailed references."
|
||||
}
|
||||
}
|
||||
|
|
|
@ -180,9 +180,23 @@ class ReportSynthesizer:
|
|||
total_chunks = len(chunks)
|
||||
logger.info(f"Starting to process {total_chunks} document chunks")
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
# Determine batch size based on the model - Gemini can handle larger batches
|
||||
if "gemini" in self.model_name.lower():
|
||||
batch_size = 8 # Larger batch size for Gemini models with 1M token windows
|
||||
else:
|
||||
batch_size = 3 # Smaller batch size for other models
|
||||
|
||||
logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
|
||||
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i+batch_size]
|
||||
logger.info(f"Processing batch {i//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size} with {len(batch)} chunks")
|
||||
|
||||
# Process this batch
|
||||
batch_results = []
|
||||
for chunk in batch:
|
||||
chunk_title = chunk.get('title', 'Untitled')
|
||||
logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_title[:50]}...")
|
||||
logger.info(f"Processing chunk {i+1}/{total_chunks}: {chunk_title[:50]}...")
|
||||
|
||||
# Create a prompt for extracting key information from the chunk
|
||||
messages = [
|
||||
|
@ -205,15 +219,22 @@ class ReportSynthesizer:
|
|||
# Add the extracted information to the chunk
|
||||
processed_chunk = chunk.copy()
|
||||
processed_chunk['extracted_info'] = extracted_info
|
||||
processed_chunks.append(processed_chunk)
|
||||
batch_results.append(processed_chunk)
|
||||
|
||||
logger.info(f"Completed chunk {i}/{total_chunks} ({(i/total_chunks)*100:.1f}% complete)")
|
||||
logger.info(f"Completed chunk {i+1}/{total_chunks} ({(i+1)/total_chunks*100:.1f}% complete)")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk {i}/{total_chunks}: {str(e)}")
|
||||
logger.error(f"Error processing chunk {i+1}/{total_chunks}: {str(e)}")
|
||||
# Add a placeholder for the failed chunk to maintain document order
|
||||
processed_chunk = chunk.copy()
|
||||
processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
|
||||
processed_chunks.append(processed_chunk)
|
||||
batch_results.append(processed_chunk)
|
||||
|
||||
processed_chunks.extend(batch_results)
|
||||
|
||||
# Add a small delay between batches to avoid rate limiting
|
||||
if i + batch_size < len(chunks):
|
||||
logger.info("Pausing briefly between batches...")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
logger.info(f"Completed processing all {total_chunks} chunks")
|
||||
return processed_chunks
|
||||
|
@ -355,7 +376,13 @@ class ReportSynthesizer:
|
|||
logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")
|
||||
|
||||
# Process chunks in batches to avoid hitting payload limits
|
||||
batch_size = 3 # Process 3 chunks at a time
|
||||
# Determine batch size based on the model - Gemini can handle larger batches
|
||||
if "gemini" in self.model_name.lower():
|
||||
batch_size = 8 # Larger batch size for Gemini models with 1M token windows
|
||||
else:
|
||||
batch_size = 3 # Smaller batch size for other models
|
||||
|
||||
logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
|
||||
processed_chunks = []
|
||||
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
|
|
|
@ -202,7 +202,13 @@ class GradioInterface:
|
|||
# Create a timestamped output file
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_suffix = ""
|
||||
|
||||
# Extract the actual model name from the description if selected
|
||||
if custom_model:
|
||||
# If the model is in the format "model_name (provider: model_display)"
|
||||
if "(" in custom_model:
|
||||
custom_model = custom_model.split(" (")[0]
|
||||
|
||||
model_name = custom_model.split('/')[-1]
|
||||
model_suffix = f"_{model_name}"
|
||||
|
||||
|
@ -315,15 +321,53 @@ class GradioInterface:
|
|||
list: List of available model names
|
||||
"""
|
||||
# Get models from config
|
||||
models = []
|
||||
|
||||
# Extract all model names from the config file
|
||||
if 'models' in self.config.config_data:
|
||||
models = list(self.config.config_data['models'].keys())
|
||||
|
||||
# If no models found, provide some defaults
|
||||
if not models:
|
||||
models = [
|
||||
"llama-3.1-8b-instant",
|
||||
"llama-3.3-70b-versatile",
|
||||
"groq/deepseek-r1-distill-llama-70b-specdec",
|
||||
"openrouter-mixtral",
|
||||
"openrouter-claude"
|
||||
"openrouter-claude",
|
||||
"gemini-2.0-flash-lite"
|
||||
]
|
||||
|
||||
return models
|
||||
|
||||
def get_model_descriptions(self):
|
||||
"""
|
||||
Get descriptions for available models.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary mapping model names to descriptions
|
||||
"""
|
||||
descriptions = {}
|
||||
model_name_to_description = {}
|
||||
|
||||
if 'models' in self.config.config_data:
|
||||
for model_name, model_config in self.config.config_data['models'].items():
|
||||
provider = model_config.get('provider', 'unknown')
|
||||
model_display = model_config.get('model_name', model_name)
|
||||
max_tokens = model_config.get('max_tokens', 'unknown')
|
||||
temperature = model_config.get('temperature', 'unknown')
|
||||
|
||||
# Create a description that includes the provider and actual model name
|
||||
display_name = f"{model_name} ({provider}: {model_display})"
|
||||
descriptions[model_name] = display_name
|
||||
|
||||
# Create a more detailed description for the dropdown tooltip
|
||||
detailed_info = f"{display_name} - Max tokens: {max_tokens}, Temperature: {temperature}"
|
||||
model_name_to_description[display_name] = detailed_info
|
||||
|
||||
self.model_name_to_description = model_name_to_description
|
||||
return descriptions
|
||||
|
||||
def create_interface(self):
|
||||
"""
|
||||
Create and return the Gradio interface.
|
||||
|
@ -401,8 +445,9 @@ class GradioInterface:
|
|||
label="Detail Level",
|
||||
info="Controls the depth and breadth of the report"
|
||||
)
|
||||
model_descriptions = self.get_model_descriptions()
|
||||
report_custom_model = gr.Dropdown(
|
||||
choices=self.get_available_models(),
|
||||
choices=list(self.model_name_to_description.keys()),
|
||||
value=None,
|
||||
label="Custom Model (Optional)",
|
||||
info="Select a custom model for report generation"
|
||||
|
|
Loading…
Reference in New Issue