Add Google Gemini model support and dynamic model selection in UI
Parent: 76cedb9528
Commit: f6f660da97
@@ -74,6 +74,10 @@ class Config:
         if provider.lower() == 'openrouter':
             env_var_name = "OPENROUTER_API_KEY"
 
+        # Special case for Google which might use GOOGLE_API_KEY
+        if provider.lower() == 'google':
+            env_var_name = "GOOGLE_API_KEY"
+
         api_key = os.environ.get(env_var_name)
 
         # If not in environment, check config file
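For context, a minimal sketch of how the lookup above resolves a key for Google, assuming the surrounding helper checks a provider-specific environment variable first and then the api_keys section of config.yaml (the function name and config layout here are illustrative, not part of this commit):

import os
from typing import Optional

def resolve_api_key(provider: str, config_data: dict) -> Optional[str]:
    # Provider-specific environment variable, e.g. GOOGLE_API_KEY for Google
    env_var_name = f"{provider.upper()}_API_KEY"
    if provider.lower() == 'openrouter':
        env_var_name = "OPENROUTER_API_KEY"
    if provider.lower() == 'google':
        env_var_name = "GOOGLE_API_KEY"

    # Environment variable takes precedence
    api_key = os.environ.get(env_var_name)
    if api_key:
        return api_key

    # If not in environment, fall back to the api_keys section of the config file
    return config_data.get('api_keys', {}).get(provider.lower())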
@@ -0,0 +1,158 @@ (new file: example configuration)
# Example configuration file for the intelligent research system
# Rename this file to config.yaml and fill in your API keys and settings

# API keys (alternatively, set environment variables)
api_keys:
  openai: "your-openai-api-key"          # Or set OPENAI_API_KEY environment variable
  jina: "your-jina-api-key"              # Or set JINA_API_KEY environment variable
  serper: "your-serper-api-key"          # Or set SERPER_API_KEY environment variable
  google: "your-google-api-key"          # Or set GOOGLE_API_KEY environment variable
  anthropic: "your-anthropic-api-key"    # Or set ANTHROPIC_API_KEY environment variable
  openrouter: "your-openrouter-api-key"  # Or set OPENROUTER_API_KEY environment variable
  groq: "your-groq-api-key"              # Or set GROQ_API_KEY environment variable

# LLM model configurations
models:
  gpt-3.5-turbo:
    provider: "openai"
    temperature: 0.7
    max_tokens: 1000
    top_p: 1.0
    endpoint: null  # Use default OpenAI endpoint

  gpt-4:
    provider: "openai"
    temperature: 0.5
    max_tokens: 2000
    top_p: 1.0
    endpoint: null  # Use default OpenAI endpoint

  claude-2:
    provider: "anthropic"
    temperature: 0.7
    max_tokens: 1500
    top_p: 1.0
    endpoint: null  # Use default Anthropic endpoint

  azure-gpt-4:
    provider: "azure"
    temperature: 0.5
    max_tokens: 2000
    top_p: 1.0
    endpoint: "https://your-azure-endpoint.openai.azure.com"
    deployment_name: "your-deployment-name"
    api_version: "2023-05-15"

  local-llama:
    provider: "ollama"
    temperature: 0.8
    max_tokens: 1000
    endpoint: "http://localhost:11434/api/generate"
    model_name: "llama2"

  llama-3.1-8b-instant:
    provider: "groq"
    model_name: "llama-3.1-8b-instant"
    temperature: 0.7
    max_tokens: 1024
    top_p: 1.0
    endpoint: "https://api.groq.com/openai/v1"

  llama-3.3-70b-versatile:
    provider: "groq"
    model_name: "llama-3.3-70b-versatile"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://api.groq.com/openai/v1"

  openrouter-mixtral:
    provider: "openrouter"
    model_name: "mistralai/mixtral-8x7b-instruct"
    temperature: 0.7
    max_tokens: 1024
    top_p: 1.0
    endpoint: "https://openrouter.ai/api/v1"

  openrouter-claude:
    provider: "openrouter"
    model_name: "anthropic/claude-3-opus"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://openrouter.ai/api/v1"

  gemini-2.0-flash-lite:
    provider: "google"
    model_name: "google/gemini-2.0-flash-lite-001"
    temperature: 0.5
    max_tokens: 2048
    top_p: 1.0
    endpoint: "https://generativelanguage.googleapis.com/v1"

# Default model to use if not specified for a module
default_model: "llama-3.1-8b-instant"  # Using Groq's Llama 3.1 8B model for testing

# Module-specific model assignments
module_models:
  # Query processing module
  query_processing:
    enhance_query: "llama-3.1-8b-instant"            # Use Groq's Llama 3.1 8B for query enhancement
    classify_query: "llama-3.1-8b-instant"           # Use Groq's Llama 3.1 8B for classification
    generate_search_queries: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for generating search queries

  # Search strategy module
  search_strategy:
    develop_strategy: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for developing search strategies
    target_selection: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for target selection

  # Document ranking module
  document_ranking:
    rerank_documents: "jina-reranker"  # Use Jina's reranker for document reranking

  # Report generation module
  report_generation:
    synthesize_report: "llama-3.3-70b-versatile"  # Use Groq's Llama 3.3 70B for report synthesis
    format_report: "llama-3.1-8b-instant"         # Use Groq's Llama 3.1 8B for formatting

# Search engine configurations
search_engines:
  google:
    enabled: true
    max_results: 10

  serper:
    enabled: true
    max_results: 10

  jina:
    enabled: true
    max_results: 10

  scholar:
    enabled: false
    max_results: 5

  arxiv:
    enabled: false
    max_results: 5

# Jina AI specific configurations
jina:
  reranker:
    model: "jina-reranker-v2-base-multilingual"  # Default reranker model
    top_n: 10  # Default number of top results to return

# UI configuration
ui:
  theme: "light"  # light or dark
  port: 7860
  share: false
  title: "Intelligent Research System"
  description: "An automated system for finding, filtering, and synthesizing information"

# System settings
system:
  cache_dir: "data/cache"
  results_dir: "data/results"
  log_level: "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
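A minimal sketch of how this file can be consumed once renamed to config.yaml: load it with PyYAML, look up a module's model assignment, then pull that model's settings (the attribute and file names here are assumptions for illustration):

import yaml  # PyYAML

with open("config.yaml") as f:
    config_data = yaml.safe_load(f)

# Which model handles report synthesis, falling back to the global default
module_models = config_data.get("module_models", {})
report_model = module_models.get("report_generation", {}).get(
    "synthesize_report", config_data.get("default_model"))

# That model's generation settings
model_config = config_data.get("models", {}).get(report_model, {})
print(report_model, model_config.get("provider"), model_config.get("max_tokens"))
# -> llama-3.3-70b-versatile groq 2048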
@@ -99,6 +99,11 @@ class LLMInterface:
                 'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                 'X-Title': 'Intelligent Research System'  # Replace with your actual app name
             }
+        elif provider == 'google':
+            # Special handling for Google Gemini models
+            params['model'] = self.model_config.get('model_name', self.model_name)
+            # Google Gemini uses a different API base
+            params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1')
         else:
             # Standard provider (OpenAI, Anthropic, etc.)
             params['model'] = self.model_name
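A rough sketch of the request parameters the new branch produces for the gemini-2.0-flash-lite entry from the example config above. The surrounding completion call is not shown in this diff, so the snippet below only mirrors the assignments added here:

model_config = {
    "provider": "google",
    "model_name": "google/gemini-2.0-flash-lite-001",
    "temperature": 0.5,
    "max_tokens": 2048,
    "endpoint": "https://generativelanguage.googleapis.com/v1",
}

params = {"temperature": model_config["temperature"],
          "max_tokens": model_config["max_tokens"]}
if model_config["provider"] == "google":
    # Use the full Gemini model identifier and its API base
    params["model"] = model_config.get("model_name")
    params["api_base"] = model_config.get(
        "endpoint", "https://generativelanguage.googleapis.com/v1")
print(params)
# {'temperature': 0.5, 'max_tokens': 2048,
#  'model': 'google/gemini-2.0-flash-lite-001',
#  'api_base': 'https://generativelanguage.googleapis.com/v1'}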
@@ -56,11 +56,11 @@ class ReportDetailLevelManager:
                 "description": "A comprehensive report with in-depth analysis, methodology, and implications."
             },
             DetailLevel.COMPREHENSIVE: {
-                "num_results": 8,
-                "token_budget": 80000,
-                "chunk_size": 800,
-                "overlap_size": 80,
-                "model": "llama-3.3-70b-versatile",
+                "num_results": 12,
+                "token_budget": 200000,
+                "chunk_size": 1200,
+                "overlap_size": 120,
+                "model": "gemini-2.0-flash-lite",
                 "description": "An exhaustive report with all available information, extensive analysis, and detailed references."
             }
         }
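In short, the comprehensive preset now pulls more sources (12 instead of 8), uses larger chunks (1,200 tokens with 120-token overlap instead of 800/80), and more than doubles the token budget (200,000 vs. 80,000); the switch from llama-3.3-70b-versatile to gemini-2.0-flash-lite is what makes that budget workable, given the 1M-token context window noted in the synthesizer changes below.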
@@ -180,40 +180,61 @@ class ReportSynthesizer:
         total_chunks = len(chunks)
         logger.info(f"Starting to process {total_chunks} document chunks")
         
-        for i, chunk in enumerate(chunks, 1):
-            chunk_title = chunk.get('title', 'Untitled')
-            logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_title[:50]}...")
-            
-            # Create a prompt for extracting key information from the chunk
-            messages = [
-                {"role": "system", "content": extraction_prompt},
-                {"role": "user", "content": f"""Query: {query}
-                
-                Document title: {chunk.get('title', 'Untitled')}
-                Document URL: {chunk.get('url', 'Unknown')}
-                
-                Document chunk content:
-                {chunk.get('content', '')}
-                
-                Extract the most relevant information from this document chunk that addresses the query."""}
-            ]
-            
-            try:
-                # Process the chunk with the LLM
-                extracted_info = await self.generate_completion(messages)
-                
-                # Add the extracted information to the chunk
-                processed_chunk = chunk.copy()
-                processed_chunk['extracted_info'] = extracted_info
-                processed_chunks.append(processed_chunk)
-                
-                logger.info(f"Completed chunk {i}/{total_chunks} ({(i/total_chunks)*100:.1f}% complete)")
-            except Exception as e:
-                logger.error(f"Error processing chunk {i}/{total_chunks}: {str(e)}")
-                # Add a placeholder for the failed chunk to maintain document order
-                processed_chunk = chunk.copy()
-                processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
-                processed_chunks.append(processed_chunk)
+        # Determine batch size based on the model - Gemini can handle larger batches
+        if "gemini" in self.model_name.lower():
+            batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+        else:
+            batch_size = 3  # Smaller batch size for other models
+        
+        logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
+        
+        for i in range(0, len(chunks), batch_size):
+            batch = chunks[i:i+batch_size]
+            logger.info(f"Processing batch {i//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size} with {len(batch)} chunks")
+            
+            # Process this batch
+            batch_results = []
+            for chunk in batch:
+                chunk_title = chunk.get('title', 'Untitled')
+                logger.info(f"Processing chunk {i+1}/{total_chunks}: {chunk_title[:50]}...")
+                
+                # Create a prompt for extracting key information from the chunk
+                messages = [
+                    {"role": "system", "content": extraction_prompt},
+                    {"role": "user", "content": f"""Query: {query}
+                    
+                    Document title: {chunk.get('title', 'Untitled')}
+                    Document URL: {chunk.get('url', 'Unknown')}
+                    
+                    Document chunk content:
+                    {chunk.get('content', '')}
+                    
+                    Extract the most relevant information from this document chunk that addresses the query."""}
+                ]
+                
+                try:
+                    # Process the chunk with the LLM
+                    extracted_info = await self.generate_completion(messages)
+                    
+                    # Add the extracted information to the chunk
+                    processed_chunk = chunk.copy()
+                    processed_chunk['extracted_info'] = extracted_info
+                    batch_results.append(processed_chunk)
+                    
+                    logger.info(f"Completed chunk {i+1}/{total_chunks} ({(i+1)/total_chunks*100:.1f}% complete)")
+                except Exception as e:
+                    logger.error(f"Error processing chunk {i+1}/{total_chunks}: {str(e)}")
+                    # Add a placeholder for the failed chunk to maintain document order
+                    processed_chunk = chunk.copy()
+                    processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
+                    batch_results.append(processed_chunk)
+            
+            processed_chunks.extend(batch_results)
+            
+            # Add a small delay between batches to avoid rate limiting
+            if i + batch_size < len(chunks):
+                logger.info("Pausing briefly between batches...")
+                await asyncio.sleep(2)
         
         logger.info(f"Completed processing all {total_chunks} chunks")
         return processed_chunks
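A quick, self-contained check of the batching arithmetic used above, applying the Gemini batch size of 8 to 10 chunks (the chunk list is a stand-in):

chunks = list(range(10))   # stand-in for document chunks
batch_size = 8             # the Gemini setting above

# Ceiling division, matching (len(chunks) + batch_size - 1) // batch_size above
num_batches = (len(chunks) + batch_size - 1) // batch_size
for i in range(0, len(chunks), batch_size):
    batch = chunks[i:i + batch_size]
    print(f"batch {i // batch_size + 1}/{num_batches}: {len(batch)} chunks")
# batch 1/2: 8 chunks
# batch 2/2: 2 chunks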
@@ -355,7 +376,13 @@ class ReportSynthesizer:
         logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")
         
         # Process chunks in batches to avoid hitting payload limits
-        batch_size = 3  # Process 3 chunks at a time
+        # Determine batch size based on the model - Gemini can handle larger batches
+        if "gemini" in self.model_name.lower():
+            batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+        else:
+            batch_size = 3  # Smaller batch size for other models
+        
+        logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
         processed_chunks = []
         
         for i in range(0, len(chunks), batch_size):
@@ -202,7 +202,13 @@ class GradioInterface:
         # Create a timestamped output file
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         model_suffix = ""
+        
+        # Extract the actual model name from the description if selected
         if custom_model:
+            # If the model is in the format "model_name (provider: model_display)"
+            if "(" in custom_model:
+                custom_model = custom_model.split(" (")[0]
+            
             model_name = custom_model.split('/')[-1]
             model_suffix = f"_{model_name}"
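A small worked example of the filename suffix this block now derives from a dropdown selection in the "model_name (provider: model_display)" format introduced in the next hunk (the selection string below is hypothetical):

custom_model = "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)"

# Strip the "(provider: model_display)" part that exists only for display
if "(" in custom_model:
    custom_model = custom_model.split(" (")[0]   # -> "gemini-2.0-flash-lite"

model_name = custom_model.split('/')[-1]   # also handles keys like "groq/..."
model_suffix = f"_{model_name}"            # -> "_gemini-2.0-flash-lite"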
@@ -315,15 +321,53 @@ class GradioInterface:
             list: List of available model names
         """
         # Get models from config
-        models = [
-            "llama-3.1-8b-instant",
-            "llama-3.3-70b-versatile",
-            "groq/deepseek-r1-distill-llama-70b-specdec",
-            "openrouter-mixtral",
-            "openrouter-claude"
-        ]
+        models = []
+        
+        # Extract all model names from the config file
+        if 'models' in self.config.config_data:
+            models = list(self.config.config_data['models'].keys())
+        
+        # If no models found, provide some defaults
+        if not models:
+            models = [
+                "llama-3.1-8b-instant",
+                "llama-3.3-70b-versatile",
+                "groq/deepseek-r1-distill-llama-70b-specdec",
+                "openrouter-mixtral",
+                "openrouter-claude",
+                "gemini-2.0-flash-lite"
+            ]
+        
         return models
     
+    def get_model_descriptions(self):
+        """
+        Get descriptions for available models.
+        
+        Returns:
+            dict: Dictionary mapping model names to descriptions
+        """
+        descriptions = {}
+        model_name_to_description = {}
+        
+        if 'models' in self.config.config_data:
+            for model_name, model_config in self.config.config_data['models'].items():
+                provider = model_config.get('provider', 'unknown')
+                model_display = model_config.get('model_name', model_name)
+                max_tokens = model_config.get('max_tokens', 'unknown')
+                temperature = model_config.get('temperature', 'unknown')
+                
+                # Create a description that includes the provider and actual model name
+                display_name = f"{model_name} ({provider}: {model_display})"
+                descriptions[model_name] = display_name
+                
+                # Create a more detailed description for the dropdown tooltip
+                detailed_info = f"{display_name} - Max tokens: {max_tokens}, Temperature: {temperature}"
+                model_name_to_description[display_name] = detailed_info
+        
+        self.model_name_to_description = model_name_to_description
+        return descriptions
+    
     def create_interface(self):
         """
         Create and return the Gradio interface.
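Given the example config earlier in this commit, the mappings built above would look roughly like this (two representative entries, values copied from that config):

descriptions = {
    "gpt-4": "gpt-4 (openai: gpt-4)",
    "gemini-2.0-flash-lite":
        "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)",
}

model_name_to_description = {
    "gpt-4 (openai: gpt-4)":
        "gpt-4 (openai: gpt-4) - Max tokens: 2000, Temperature: 0.5",
    "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)":
        "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001) "
        "- Max tokens: 2048, Temperature: 0.5",
}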
@@ -401,8 +445,9 @@ class GradioInterface:
                         label="Detail Level",
                         info="Controls the depth and breadth of the report"
                     )
+                    model_descriptions = self.get_model_descriptions()
                     report_custom_model = gr.Dropdown(
-                        choices=self.get_available_models(),
+                        choices=list(self.model_name_to_description.keys()),
                         value=None,
                         label="Custom Model (Optional)",
                         info="Select a custom model for report generation"