From d76cd9d79bcb07522b58557ab27ccd0dafcbfd53 Mon Sep 17 00:00:00 2001
From: Steve White
Date: Tue, 18 Mar 2025 17:31:40 -0500
Subject: [PATCH] Fix model selection in report generation to ensure
 UI-selected model is properly used throughout the entire report generation
 pipeline

---
 report/report_generator.py | 48 +++++++++++++++++++++++++---
 report/report_synthesis.py | 39 +++++++++++++++++++++--
 ui/gradio_interface.py     | 65 ++++++++++++++++++++++++++++++++++----
 3 files changed, 137 insertions(+), 15 deletions(-)

diff --git a/report/report_generator.py b/report/report_generator.py
index b40202c..427135a 100644
--- a/report/report_generator.py
+++ b/report/report_generator.py
@@ -50,6 +50,27 @@ class ReportGenerator:
         await initialize_database()
         logger.info("Report generator initialized")
 
+    def set_model(self, model_name: str) -> None:
+        """
+        Set the model to use for report generation.
+
+        Args:
+            model_name: Name of the model to use
+        """
+        if model_name and model_name != self.model_name:
+            self.model_name = model_name
+
+            # Update synthesizers to use this model
+            self.report_synthesizer = get_report_synthesizer(model_name)
+            self.progressive_report_synthesizer = get_progressive_report_synthesizer(model_name)
+
+            # Also update the model in the current detail level config
+            config = self.get_detail_level_config()
+            if config:
+                config["model"] = model_name
+
+            logger.info(f"Set model to {model_name}")
+
     def set_detail_level(self, detail_level: str) -> None:
         """
         Set the detail level for report generation.
@@ -62,14 +83,14 @@ class ReportGenerator:
             config = self.detail_level_manager.get_detail_level_config(detail_level)
             self.detail_level = detail_level
 
-            # Update model if needed
+            # Update model if needed and no custom model is set
             model = config.get("model")
-            if model and model != self.model_name:
+            if model and model != self.model_name and not self.model_name:
                 self.model_name = model
                 self.report_synthesizer = get_report_synthesizer(model)
                 self.progressive_report_synthesizer = get_progressive_report_synthesizer(model)
 
-            logger.info(f"Detail level set to {detail_level} with model {model}")
+            logger.info(f"Detail level set to {detail_level} with model {self.model_name or model}")
         except ValueError as e:
             logger.error(f"Error setting detail level: {e}")
             raise
@@ -236,6 +257,17 @@ class ReportGenerator:
         if detail_level:
             self.set_detail_level(detail_level)
 
+        # Log information about current model and reinitialize synthesizers to ensure they use the correct model
+        logger.info(f"Starting report generation with model: {self.model_name or 'default model'}")
+
+        # Reinitialize the synthesizers with the current model to ensure they use the correct model
+        self.report_synthesizer = get_report_synthesizer(self.model_name)
+        self.progressive_report_synthesizer = get_progressive_report_synthesizer(self.model_name)
+
+        # Double-check that the synthesizers are using the correct model
+        logger.info(f"Report synthesizer model: {self.report_synthesizer.model_name}")
+        logger.info(f"Progressive report synthesizer model: {self.progressive_report_synthesizer.model_name}")
+
         # Prepare documents for report
         selected_chunks = await self.prepare_documents_for_report(
             search_results,
@@ -277,7 +309,10 @@ class ReportGenerator:
         # If no sub-questions or structured_query is None, use standard synthesizers
         elif self.detail_level.lower() == "comprehensive":
             # Use progressive report synthesizer for comprehensive detail level
-            logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level")
logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level with model {self.model_name}") + # Verify that the report synthesizer is using the correct model + if hasattr(self.progressive_report_synthesizer, 'model_name'): + logger.info(f"Progressive synthesizer model: {self.progressive_report_synthesizer.model_name}") report = await self.progressive_report_synthesizer.synthesize_report( selected_chunks, query, @@ -286,7 +321,10 @@ class ReportGenerator: ) else: # Use standard report synthesizer for other detail levels - logger.info(f"Using standard report synthesizer for {self.detail_level} detail level") + logger.info(f"Using standard report synthesizer for {self.detail_level} detail level with model {self.model_name}") + # Verify that the report synthesizer is using the correct model + if hasattr(self.report_synthesizer, 'model_name'): + logger.info(f"Standard synthesizer model: {self.report_synthesizer.model_name}") report = await self.report_synthesizer.synthesize_report( selected_chunks, query, diff --git a/report/report_synthesis.py b/report/report_synthesis.py index 5acba99..513e9e0 100644 --- a/report/report_synthesis.py +++ b/report/report_synthesis.py @@ -98,6 +98,10 @@ class ReportSynthesizer: """Set up the LLM provider based on the model configuration.""" provider = self.model_config.get('provider', 'groq') + # Log detailed model information for debugging + logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})") + logger.info(f"Model configuration: {self.model_config}") + try: # Get API key for the provider api_key = self.config.get_api_key(provider) @@ -105,12 +109,15 @@ class ReportSynthesizer: # Set environment variable for the provider if provider.lower() == 'google' or provider.lower() == 'gemini': os.environ["GEMINI_API_KEY"] = api_key + logger.info("Configured with GEMINI_API_KEY") elif provider.lower() == 'vertex_ai': os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key + logger.info("Configured with GOOGLE_APPLICATION_CREDENTIALS") else: os.environ[f"{provider.upper()}_API_KEY"] = api_key + logger.info(f"Configured with {provider.upper()}_API_KEY") - logger.info(f"Report synthesizer initialized with model: {self.model_name} (provider: {provider})") + logger.info(f"Report synthesizer successfully initialized with model: {self.model_name} (provider: {provider})") except ValueError as e: logger.error(f"Error setting up LLM provider: {e}") @@ -229,6 +236,18 @@ class ReportSynthesizer: # Get completion parameters params = self._get_completion_params() + # Double-check that we're using the correct model + if provider == 'gemini': + params['model'] = f"gemini/{self.model_name}" + elif provider == 'groq': + params['model'] = f"groq/{self.model_name}" + + # Log the actual parameters being used for the LLM call + safe_params = params.copy() + if 'headers' in safe_params: + safe_params['headers'] = "Redacted for logging" + logger.info(f"LLM completion parameters: {safe_params}") + try: # Generate completion if stream: @@ -239,6 +258,7 @@ class ReportSynthesizer: ) return response else: + logger.info(f"Sending request to {params.get('model', 'unknown model')} with {len(formatted_messages)} messages") response = litellm.completion( messages=formatted_messages, **params @@ -304,7 +324,7 @@ class ReportSynthesizer: extraction_prompt = self._get_extraction_prompt(detail_level, query_type) total_chunks = len(chunks) - logger.info(f"Starting to process {total_chunks} document chunks") + logger.info(f"Starting to 
+        logger.info(f"Starting to process {total_chunks} document chunks with model: {self.model_name}")
 
         # Update progress tracking state
         self.total_chunks = total_chunks
@@ -708,7 +728,20 @@ def get_report_synthesizer(model_name: Optional[str] = None) -> ReportSynthesize
     global report_synthesizer
 
     if model_name and model_name != report_synthesizer.model_name:
-        report_synthesizer = ReportSynthesizer(model_name)
+        logger.info(f"Creating new report synthesizer with model: {model_name}")
+        try:
+            previous_model = report_synthesizer.model_name
+            report_synthesizer = ReportSynthesizer(model_name)
+            logger.info(f"Successfully changed model from {previous_model} to {model_name}")
+        except Exception as e:
+            logger.error(f"Error creating new report synthesizer with model {model_name}: {str(e)}")
+            # Fall back to the existing synthesizer
+            logger.info(f"Falling back to existing synthesizer with model {report_synthesizer.model_name}")
+    else:
+        if model_name:
+            logger.info(f"Using existing report synthesizer with model: {model_name} (already initialized)")
+        else:
+            logger.info(f"Using existing report synthesizer with default model: {report_synthesizer.model_name}")
 
     return report_synthesizer
diff --git a/ui/gradio_interface.py b/ui/gradio_interface.py
index 4a85e13..2add6e9 100644
--- a/ui/gradio_interface.py
+++ b/ui/gradio_interface.py
@@ -213,11 +213,17 @@ class GradioInterface:
         # Extract the actual model name from the description if selected
         if custom_model:
             # If the model is in the format "model_name (provider: model_display)"
+            original_custom_model = custom_model
             if "(" in custom_model:
                 custom_model = custom_model.split(" (")[0]
 
             model_name = custom_model.split('/')[-1]
             model_suffix = f"_{model_name}"
+
+            # Log the model selection for debugging
+            print(f"Selected model from UI: {original_custom_model}")
+            print(f"Extracted model name: {custom_model}")
+            print(f"Using model suffix: {model_suffix}")
 
         output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"
@@ -234,7 +240,10 @@
         # If custom model is provided, use it
         if custom_model:
-            config["model"] = custom_model
+            # Extract the actual model name from the display name format if needed
+            model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
+            config["model"] = model_name
+            print(f"Using custom model: {model_name}")
 
         # Ensure report generator is initialized
         if self.report_generator is None:
@@ -242,8 +251,33 @@
             await initialize_report_generator()
             self.report_generator = get_report_generator()
 
-        # This will update the report synthesizer to use the custom model
+        # Debug: Print initial model configuration based on detail level
+        detail_config = self.detail_level_manager.get_detail_level_config(detail_level)
+        default_model = detail_config.get("model", "unknown")
+        print(f"Default model for {detail_level} detail level: {default_model}")
+
+        # First set the detail level, which will set the default model for this detail level
         self.report_generator.set_detail_level(detail_level)
+        print(f"After setting detail level, report generator model is: {self.report_generator.model_name}")
+
+        # Then explicitly override with custom model if provided
+        if custom_model:
+            # Extract the actual model name from the display name format
+            # The format is "model_name (provider: model_display)"
+            model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
+            print(f"Setting report generator to use custom model: {model_name}")
+
+            # Look for a set_model method in the report generator
+            if hasattr(self.report_generator, 'set_model'):
+                self.report_generator.set_model(model_name)
+                print(f"After setting custom model, report generator model is: {self.report_generator.model_name}")
+            else:
+                print("Warning: Report generator does not have set_model method. Using alternative approach.")
+                # Update the config with the model as a fallback
+                current_config = self.report_generator.get_detail_level_config()
+                if current_config:
+                    current_config["model"] = model_name
+                    print(f"Updated config model to: {model_name}")
 
         print(f"Generating report with detail level: {detail_level}")
         print(f"Detail level configuration: {config}")
@@ -269,6 +303,12 @@
             self.search_executor.get_available_search_engines()
         )
 
+        # Set the number of results to fetch per engine early so it's available throughout the function
+        num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
+
+        # Initialize sub_question_results as an empty dict in case there are no sub-questions
+        sub_question_results = {}
+
         # Check if the query was decomposed into sub-questions
         has_sub_questions = 'sub_questions' in structured_query and structured_query['sub_questions']
         if has_sub_questions:
@@ -295,10 +335,6 @@
             )
             progress(0.2, desc="Completed sub-question searches")
 
-        # Execute the search with the structured query
-        # Use initial_results_per_engine if available, otherwise fall back to num_results
-        num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
-
         # Execute main search
         progress(0.3, desc="Executing main search...")
         search_results_dict = self.search_executor.execute_search(
@@ -413,6 +449,11 @@
             if chunk_title:
                 status_message += f" ({chunk_title})"
 
+        # Add model information to status message
+        if hasattr(self.report_generator, 'model_name') and self.report_generator.model_name:
+            model_display = self.report_generator.model_name.split('/')[-1]  # Extract model name without provider
+            status_message += f" (Using model: {model_display})"
+
         # Update the progress status directly
         return status_message
@@ -435,6 +476,9 @@
         if len(search_results) == 0:
             print("WARNING: No search results found. Report generation may fail.")
 
+        # Log the current model being used by the report generator
+        print(f"Report generator is using model: {self.report_generator.model_name}")
+
         # Update progress status based on detail level
         if detail_level.lower() == "comprehensive":
             self.progress_status = "Generating progressive report..."
@@ -451,6 +495,10 @@ class GradioInterface:
         else:
             print("Using auto-detection for query type")
 
+        # Ensure structured_query is defined
+        if not locals().get('structured_query'):
+            structured_query = None
+
         report = await self.report_generator.generate_report(
             search_results=search_results,
             query=query,
@@ -459,7 +507,7 @@
             overlap_size=config["overlap_size"],
             detail_level=detail_level,
             query_type=actual_query_type,
-            structured_query=structured_query if 'sub_questions' in structured_query else None
+            structured_query=structured_query if structured_query and 'sub_questions' in structured_query else None
         )
 
         # Final progress update
@@ -747,6 +795,9 @@
         # Set up progress tracking
         progress_data = gr.Progress(track_tqdm=True)
 
+        # Debug the model selection
+        print(f"Model selected from UI dropdown: {model_name}")
+
         # Call the original generate_report method
         result = await self.generate_report(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results)
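-- 
Reviewer note (appended after the signature delimiter so git-am ignores it;
not part of the commit): a minimal sketch of the call order this change
establishes. It assumes get_report_generator and initialize_report_generator
are importable from report.report_generator (the import path is not shown in
the diff), and the model string below is a hypothetical example.

    import asyncio

    from report.report_generator import (  # assumed import path
        get_report_generator,
        initialize_report_generator,
    )

    async def main():
        await initialize_report_generator()
        generator = get_report_generator()

        # Detail level first: this applies the level's default model only if
        # no model has been set yet (the tightened check in set_detail_level).
        generator.set_detail_level("comprehensive")

        # The UI-selected model then explicitly overrides the default and
        # rebuilds both synthesizers via set_model.
        generator.set_model("llama-3.3-70b-versatile")  # hypothetical model name

        assert generator.model_name == "llama-3.3-70b-versatile"

    asyncio.run(main())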