From f6f660da974d4949cfb302a095f95b485f431bf7 Mon Sep 17 00:00:00 2001
From: Steve White
Date: Fri, 28 Feb 2025 11:31:53 -0600
Subject: [PATCH] Add Google Gemini model support and dynamic model selection in UI

---
 config/config.py               |   4 +
 config/config.yaml             | 158 +++++++++++++++++++++++++++++++++
 query/llm_interface.py         |   5 ++
 report/report_detail_levels.py |  10 +--
 report/report_synthesis.py     |  89 ++++++++++++-------
 ui/gradio_interface.py         |  61 +++++++++++--
 6 files changed, 283 insertions(+), 44 deletions(-)
 create mode 100644 config/config.yaml

diff --git a/config/config.py b/config/config.py
index 5f7f3b5..9be4a8e 100644
--- a/config/config.py
+++ b/config/config.py
@@ -74,6 +74,10 @@ class Config:
         if provider.lower() == 'openrouter':
             env_var_name = "OPENROUTER_API_KEY"
 
+        # Special case for Google which might use GOOGLE_API_KEY
+        if provider.lower() == 'google':
+            env_var_name = "GOOGLE_API_KEY"
+
         api_key = os.environ.get(env_var_name)
 
         # If not in environment, check config file
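The config.py hunk makes the key lookup provider-aware: the provider-specific environment variable is checked first, then the api_keys block of config.yaml. The sketch below summarizes that order; resolve_api_key, the default naming rule, and the config_data dict are illustrative stand-ins, not the actual Config API.

    import os

    def resolve_api_key(provider: str, config_data: dict):
        """Illustrative sketch of the env-var-first, config-file-second lookup."""
        env_var_name = f"{provider.upper()}_API_KEY"  # assumed default convention, not confirmed by the patch
        if provider.lower() == 'openrouter':
            env_var_name = "OPENROUTER_API_KEY"
        if provider.lower() == 'google':
            env_var_name = "GOOGLE_API_KEY"  # the case this patch adds
        # The environment variable wins; otherwise fall back to the api_keys block of config.yaml
        return os.environ.get(env_var_name) or config_data.get('api_keys', {}).get(provider.lower())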
diff --git a/config/config.yaml b/config/config.yaml
new file mode 100644
index 0000000..b5e945e
--- /dev/null
+++ b/config/config.yaml
@@ -0,0 +1,158 @@
+# Example configuration file for the intelligent research system
+# Rename this file to config.yaml and fill in your API keys and settings
+
+# API keys (alternatively, set environment variables)
+api_keys:
+  openai: "your-openai-api-key"  # Or set OPENAI_API_KEY environment variable
+  jina: "your-jina-api-key"  # Or set JINA_API_KEY environment variable
+  serper: "your-serper-api-key"  # Or set SERPER_API_KEY environment variable
+  google: "your-google-api-key"  # Or set GOOGLE_API_KEY environment variable
+  anthropic: "your-anthropic-api-key"  # Or set ANTHROPIC_API_KEY environment variable
+  openrouter: "your-openrouter-api-key"  # Or set OPENROUTER_API_KEY environment variable
+  groq: "your-groq-api-key"  # Or set GROQ_API_KEY environment variable
+
+# LLM model configurations
+models:
+  gpt-3.5-turbo:
+    provider: "openai"
+    temperature: 0.7
+    max_tokens: 1000
+    top_p: 1.0
+    endpoint: null  # Use default OpenAI endpoint
+
+  gpt-4:
+    provider: "openai"
+    temperature: 0.5
+    max_tokens: 2000
+    top_p: 1.0
+    endpoint: null  # Use default OpenAI endpoint
+
+  claude-2:
+    provider: "anthropic"
+    temperature: 0.7
+    max_tokens: 1500
+    top_p: 1.0
+    endpoint: null  # Use default Anthropic endpoint
+
+  azure-gpt-4:
+    provider: "azure"
+    temperature: 0.5
+    max_tokens: 2000
+    top_p: 1.0
+    endpoint: "https://your-azure-endpoint.openai.azure.com"
+    deployment_name: "your-deployment-name"
+    api_version: "2023-05-15"
+
+  local-llama:
+    provider: "ollama"
+    temperature: 0.8
+    max_tokens: 1000
+    endpoint: "http://localhost:11434/api/generate"
+    model_name: "llama2"
+
+  llama-3.1-8b-instant:
+    provider: "groq"
+    model_name: "llama-3.1-8b-instant"
+    temperature: 0.7
+    max_tokens: 1024
+    top_p: 1.0
+    endpoint: "https://api.groq.com/openai/v1"
+
+  llama-3.3-70b-versatile:
+    provider: "groq"
+    model_name: "llama-3.3-70b-versatile"
+    temperature: 0.5
+    max_tokens: 2048
+    top_p: 1.0
+    endpoint: "https://api.groq.com/openai/v1"
+
+  openrouter-mixtral:
+    provider: "openrouter"
+    model_name: "mistralai/mixtral-8x7b-instruct"
+    temperature: 0.7
+    max_tokens: 1024
+    top_p: 1.0
+    endpoint: "https://openrouter.ai/api/v1"
+
+  openrouter-claude:
+    provider: "openrouter"
+    model_name: "anthropic/claude-3-opus"
+    temperature: 0.5
+    max_tokens: 2048
+    top_p: 1.0
+    endpoint: "https://openrouter.ai/api/v1"
+
+  gemini-2.0-flash-lite:
+    provider: "google"
+    model_name: "google/gemini-2.0-flash-lite-001"
+    temperature: 0.5
+    max_tokens: 2048
+    top_p: 1.0
+    endpoint: "https://generativelanguage.googleapis.com/v1"
+
+# Default model to use if not specified for a module
+default_model: "llama-3.1-8b-instant"  # Using Groq's Llama 3.1 8B model for testing
+
+# Module-specific model assignments
+module_models:
+  # Query processing module
+  query_processing:
+    enhance_query: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for query enhancement
+    classify_query: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for classification
+    generate_search_queries: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for generating search queries
+
+  # Search strategy module
+  search_strategy:
+    develop_strategy: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for developing search strategies
+    target_selection: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for target selection
+
+  # Document ranking module
+  document_ranking:
+    rerank_documents: "jina-reranker"  # Use Jina's reranker for document reranking
+
+  # Report generation module
+  report_generation:
+    synthesize_report: "llama-3.3-70b-versatile"  # Use Groq's Llama 3.3 70B for report synthesis
+    format_report: "llama-3.1-8b-instant"  # Use Groq's Llama 3.1 8B for formatting
+
+# Search engine configurations
+search_engines:
+  google:
+    enabled: true
+    max_results: 10
+
+  serper:
+    enabled: true
+    max_results: 10
+
+  jina:
+    enabled: true
+    max_results: 10
+
+  scholar:
+    enabled: false
+    max_results: 5
+
+  arxiv:
+    enabled: false
+    max_results: 5
+
+# Jina AI specific configurations
+jina:
+  reranker:
+    model: "jina-reranker-v2-base-multilingual"  # Default reranker model
+    top_n: 10  # Default number of top results to return
+
+# UI configuration
+ui:
+  theme: "light"  # light or dark
+  port: 7860
+  share: false
+  title: "Intelligent Research System"
+  description: "An automated system for finding, filtering, and synthesizing information"
+
+# System settings
+system:
+  cache_dir: "data/cache"
+  results_dir: "data/results"
+  log_level: "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
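The new config.yaml is the single source of truth for model definitions, per-module assignments, and the fallback default. A minimal PyYAML sketch of how such a lookup can be resolved is shown below; resolve_module_model is a hypothetical helper written for illustration, not part of the project's Config class.

    import yaml

    def resolve_module_model(config_path: str, module: str, function: str) -> dict:
        """Illustrative: pick the model entry for a module function, falling back to default_model."""
        with open(config_path) as fh:
            cfg = yaml.safe_load(fh)
        model_key = cfg.get('module_models', {}).get(module, {}).get(function, cfg['default_model'])
        return {model_key: cfg['models'][model_key]}

    # Example: resolve_module_model("config/config.yaml", "report_generation", "synthesize_report")
    # returns {'llama-3.3-70b-versatile': {...}} with the settings defined above.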
diff --git a/query/llm_interface.py b/query/llm_interface.py
index 70059df..c87b4cd 100644
--- a/query/llm_interface.py
+++ b/query/llm_interface.py
@@ -99,6 +99,11 @@ class LLMInterface:
                 'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                 'X-Title': 'Intelligent Research System'  # Replace with your actual app name
             }
+        elif provider == 'google':
+            # Special handling for Google Gemini models
+            params['model'] = self.model_config.get('model_name', self.model_name)
+            # Google Gemini uses a different API base
+            params['api_base'] = self.model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1')
         else:
             # Standard provider (OpenAI, Anthropic, etc.)
             params['model'] = self.model_name

diff --git a/report/report_detail_levels.py b/report/report_detail_levels.py
index 2340842..c70be47 100644
--- a/report/report_detail_levels.py
+++ b/report/report_detail_levels.py
@@ -56,11 +56,11 @@ class ReportDetailLevelManager:
                 "description": "A comprehensive report with in-depth analysis, methodology, and implications."
            },
            DetailLevel.COMPREHENSIVE: {
-                "num_results": 8,
-                "token_budget": 80000,
-                "chunk_size": 800,
-                "overlap_size": 80,
-                "model": "llama-3.3-70b-versatile",
+                "num_results": 12,
+                "token_budget": 200000,
+                "chunk_size": 1200,
+                "overlap_size": 120,
+                "model": "gemini-2.0-flash-lite",
                 "description": "An exhaustive report with all available information, extensive analysis, and detailed references."
            }
        }
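Combined with the gemini-2.0-flash-lite entry in config.yaml, the new 'google' branch produces request parameters shaped roughly as follows. The snippet mirrors the branch above with a hand-written model_config dict for illustration; sampling settings such as temperature and max_tokens are presumably merged in elsewhere in LLMInterface and are not shown here.

    # Illustrative check of the new 'google' branch, using the config.yaml entry above as model_config.
    model_config = {
        'provider': 'google',
        'model_name': 'google/gemini-2.0-flash-lite-001',
        'endpoint': 'https://generativelanguage.googleapis.com/v1',
    }
    params = {}
    if model_config['provider'] == 'google':
        params['model'] = model_config.get('model_name')
        params['api_base'] = model_config.get('endpoint', 'https://generativelanguage.googleapis.com/v1')
    assert params == {
        'model': 'google/gemini-2.0-flash-lite-001',
        'api_base': 'https://generativelanguage.googleapis.com/v1',
    }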
diff --git a/report/report_synthesis.py b/report/report_synthesis.py
index 1acdfb7..e9cb071 100644
--- a/report/report_synthesis.py
+++ b/report/report_synthesis.py
@@ -180,40 +180,61 @@ class ReportSynthesizer:
         total_chunks = len(chunks)
         logger.info(f"Starting to process {total_chunks} document chunks")
 
-        for i, chunk in enumerate(chunks, 1):
-            chunk_title = chunk.get('title', 'Untitled')
-            logger.info(f"Processing chunk {i}/{total_chunks}: {chunk_title[:50]}...")
+        # Determine batch size based on the model - Gemini can handle larger batches
+        if "gemini" in self.model_name.lower():
+            batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+        else:
+            batch_size = 3  # Smaller batch size for other models
 
-            # Create a prompt for extracting key information from the chunk
-            messages = [
-                {"role": "system", "content": extraction_prompt},
-                {"role": "user", "content": f"""Query: {query}
-
-                Document title: {chunk.get('title', 'Untitled')}
-                Document URL: {chunk.get('url', 'Unknown')}
-
-                Document chunk content:
-                {chunk.get('content', '')}
-
-                Extract the most relevant information from this document chunk that addresses the query."""}
-            ]
+        logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
+
+        for i in range(0, len(chunks), batch_size):
+            batch = chunks[i:i+batch_size]
+            logger.info(f"Processing batch {i//batch_size + 1}/{(len(chunks) + batch_size - 1)//batch_size} with {len(batch)} chunks")
 
-            try:
-                # Process the chunk with the LLM
-                extracted_info = await self.generate_completion(messages)
+            # Process this batch
+            batch_results = []
+            for j, chunk in enumerate(batch):
+                chunk_title = chunk.get('title', 'Untitled')
+                logger.info(f"Processing chunk {i+j+1}/{total_chunks}: {chunk_title[:50]}...")
 
-                # Add the extracted information to the chunk
-                processed_chunk = chunk.copy()
-                processed_chunk['extracted_info'] = extracted_info
-                processed_chunks.append(processed_chunk)
+                # Create a prompt for extracting key information from the chunk
+                messages = [
+                    {"role": "system", "content": extraction_prompt},
+                    {"role": "user", "content": f"""Query: {query}
+
+                    Document title: {chunk.get('title', 'Untitled')}
+                    Document URL: {chunk.get('url', 'Unknown')}
+
+                    Document chunk content:
+                    {chunk.get('content', '')}
+
+                    Extract the most relevant information from this document chunk that addresses the query."""}
+                ]
 
-                logger.info(f"Completed chunk {i}/{total_chunks} ({(i/total_chunks)*100:.1f}% complete)")
-            except Exception as e:
-                logger.error(f"Error processing chunk {i}/{total_chunks}: {str(e)}")
-                # Add a placeholder for the failed chunk to maintain document order
-                processed_chunk = chunk.copy()
-                processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
-                processed_chunks.append(processed_chunk)
+                try:
+                    # Process the chunk with the LLM
+                    extracted_info = await self.generate_completion(messages)
+
+                    # Add the extracted information to the chunk
+                    processed_chunk = chunk.copy()
+                    processed_chunk['extracted_info'] = extracted_info
+                    batch_results.append(processed_chunk)
+
+                    logger.info(f"Completed chunk {i+j+1}/{total_chunks} ({(i+j+1)/total_chunks*100:.1f}% complete)")
+                except Exception as e:
+                    logger.error(f"Error processing chunk {i+j+1}/{total_chunks}: {str(e)}")
+                    # Add a placeholder for the failed chunk to maintain document order
+                    processed_chunk = chunk.copy()
+                    processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
+                    batch_results.append(processed_chunk)
+
+            processed_chunks.extend(batch_results)
+
+            # Add a small delay between batches to avoid rate limiting
+            if i + batch_size < len(chunks):
+                logger.info("Pausing briefly between batches...")
+                await asyncio.sleep(2)
 
         logger.info(f"Completed processing all {total_chunks} chunks")
         return processed_chunks
@@ -355,7 +376,13 @@ class ReportSynthesizer:
         logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")
 
         # Process chunks in batches to avoid hitting payload limits
-        batch_size = 3  # Process 3 chunks at a time
+        # Determine batch size based on the model - Gemini can handle larger batches
+        if "gemini" in self.model_name.lower():
+            batch_size = 8  # Larger batch size for Gemini models with 1M token windows
+        else:
+            batch_size = 3  # Smaller batch size for other models
+
+        logger.info(f"Using batch size of {batch_size} for model {self.model_name}")
 
         processed_chunks = []
         for i in range(0, len(chunks), batch_size):
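The chunk-processing rewrite boils down to three ideas: pick a larger batch size when a long-context Gemini model is in use, pause briefly between batches to stay under rate limits, and record an error placeholder for any failed item so document order is preserved. The self-contained sketch below captures that pattern; process_in_batches and worker are illustrative names, not the ReportSynthesizer API.

    import asyncio

    async def process_in_batches(items, worker, model_name: str):
        """Illustrative: model-dependent batch size, inter-batch pause, order-preserving error handling."""
        batch_size = 8 if "gemini" in model_name.lower() else 3
        results = []
        for start in range(0, len(items), batch_size):
            for item in items[start:start + batch_size]:
                try:
                    results.append(await worker(item))
                except Exception as exc:
                    results.append(f"Error extracting information: {exc}")  # placeholder keeps order
            if start + batch_size < len(items):
                await asyncio.sleep(2)  # brief pause between batches
        return results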
diff --git a/ui/gradio_interface.py b/ui/gradio_interface.py
index eabb231..9f28a73 100644
--- a/ui/gradio_interface.py
+++ b/ui/gradio_interface.py
@@ -202,7 +202,13 @@ class GradioInterface:
         # Create a timestamped output file
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         model_suffix = ""
+
+        # Extract the actual model name from the description if selected
         if custom_model:
+            # If the model is in the format "model_name (provider: model_display)"
+            if "(" in custom_model:
+                custom_model = custom_model.split(" (")[0]
+
             model_name = custom_model.split('/')[-1]
             model_suffix = f"_{model_name}"
 
@@ -315,15 +321,53 @@ class GradioInterface:
             list: List of available model names
         """
         # Get models from config
-        models = [
-            "llama-3.1-8b-instant",
-            "llama-3.3-70b-versatile",
-            "groq/deepseek-r1-distill-llama-70b-specdec",
-            "openrouter-mixtral",
-            "openrouter-claude"
-        ]
+        models = []
+
+        # Extract all model names from the config file
+        if 'models' in self.config.config_data:
+            models = list(self.config.config_data['models'].keys())
+
+        # If no models found, provide some defaults
+        if not models:
+            models = [
+                "llama-3.1-8b-instant",
+                "llama-3.3-70b-versatile",
+                "groq/deepseek-r1-distill-llama-70b-specdec",
+                "openrouter-mixtral",
+                "openrouter-claude",
+                "gemini-2.0-flash-lite"
+            ]
+
         return models
 
+    def get_model_descriptions(self):
+        """
+        Get descriptions for available models.
+
+        Returns:
+            dict: Dictionary mapping model names to descriptions
+        """
+        descriptions = {}
+        model_name_to_description = {}
+
+        if 'models' in self.config.config_data:
+            for model_name, model_config in self.config.config_data['models'].items():
+                provider = model_config.get('provider', 'unknown')
+                model_display = model_config.get('model_name', model_name)
+                max_tokens = model_config.get('max_tokens', 'unknown')
+                temperature = model_config.get('temperature', 'unknown')
+
+                # Create a description that includes the provider and actual model name
+                display_name = f"{model_name} ({provider}: {model_display})"
+                descriptions[model_name] = display_name
+
+                # Create a more detailed description for the dropdown tooltip
+                detailed_info = f"{display_name} - Max tokens: {max_tokens}, Temperature: {temperature}"
+                model_name_to_description[display_name] = detailed_info
+
+        self.model_name_to_description = model_name_to_description
+        return descriptions
+
     def create_interface(self):
         """
         Create and return the Gradio interface.
@@ -401,8 +445,9 @@ class GradioInterface:
                         label="Detail Level",
                         info="Controls the depth and breadth of the report"
                     )
+                    model_descriptions = self.get_model_descriptions()
                     report_custom_model = gr.Dropdown(
-                        choices=self.get_available_models(),
+                        choices=list(self.model_name_to_description.keys()),
                         value=None,
                         label="Custom Model (Optional)",
                         info="Select a custom model for report generation"
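The dropdown now shows display strings of the form "key (provider: model_name)" while the report pipeline still expects the bare config key, so the earlier hunk in this file strips the parenthetical before using the selection. A minimal sketch of that round trip, with an illustrative function name:

    def display_to_model_key(selection: str) -> str:
        """Illustrative: recover the config key from a dropdown display string."""
        return selection.split(" (")[0] if "(" in selection else selection

    assert display_to_model_key(
        "gemini-2.0-flash-lite (google: google/gemini-2.0-flash-lite-001)"
    ) == "gemini-2.0-flash-lite"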