Fix model selection in report generation to ensure UI-selected model is properly used throughout the entire report generation pipeline

Steve White 2025-03-18 17:31:40 -05:00
parent 76748f504e
commit d76cd9d79b
3 changed files with 137 additions and 15 deletions
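
The change, in brief: the UI handler previously called only set_detail_level(), whose default model could silently replace the one chosen in the dropdown. The new flow applies the detail level first, then overrides it with the UI selection via the new ReportGenerator.set_model(), which rebuilds both synthesizers. A minimal sketch of the intended call order (the model string is hypothetical, for illustration only):

    generator = get_report_generator()

    # 1. Apply the detail level; its default model fills in only when no model is set yet.
    generator.set_detail_level("comprehensive")

    # 2. Override with the model chosen in the UI; this rebuilds both synthesizers.
    generator.set_model("gemini-2.0-flash")  # hypothetical model name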

View File

@@ -50,6 +50,27 @@ class ReportGenerator:
         await initialize_database()
         logger.info("Report generator initialized")
 
+    def set_model(self, model_name: str) -> None:
+        """
+        Set the model to use for report generation.
+
+        Args:
+            model_name: Name of the model to use
+        """
+
+        if model_name and model_name != self.model_name:
+            self.model_name = model_name
+            # Update synthesizers to use this model
+            self.report_synthesizer = get_report_synthesizer(model_name)
+            self.progressive_report_synthesizer = get_progressive_report_synthesizer(model_name)
+
+            # Also update the model in the current detail level config
+            config = self.get_detail_level_config()
+            if config:
+                config["model"] = model_name
+
+            logger.info(f"Set model to {model_name}")
+
     def set_detail_level(self, detail_level: str) -> None:
         """
         Set the detail level for report generation.
@@ -62,14 +83,14 @@ class ReportGenerator:
             config = self.detail_level_manager.get_detail_level_config(detail_level)
             self.detail_level = detail_level
 
-            # Update model if needed
+            # Update model if needed and no custom model is set
             model = config.get("model")
-            if model and model != self.model_name:
+            if model and model != self.model_name and not self.model_name:
                 self.model_name = model
                 self.report_synthesizer = get_report_synthesizer(model)
                 self.progressive_report_synthesizer = get_progressive_report_synthesizer(model)
 
-            logger.info(f"Detail level set to {detail_level} with model {model}")
+            logger.info(f"Detail level set to {detail_level} with model {self.model_name or model}")
         except ValueError as e:
             logger.error(f"Error setting detail level: {e}")
             raise
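
With the added "not self.model_name" guard, a model that was set explicitly survives later detail-level changes; the level's default now only fills an empty slot. A sketch of the resulting behavior (direct construction assumed for illustration):

    gen = ReportGenerator()           # assumes model_name starts out unset
    gen.set_model("my-custom-model")  # hypothetical UI selection
    gen.set_detail_level("brief")     # keeps "my-custom-model"; the level's default is skipped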
@@ -236,6 +257,17 @@ class ReportGenerator:
         if detail_level:
             self.set_detail_level(detail_level)
 
+        # Log information about current model and reinitialize synthesizers to ensure they use the correct model
+        logger.info(f"Starting report generation with model: {self.model_name or 'default model'}")
+
+        # Reinitialize the synthesizers with the current model to ensure they use the correct model
+        self.report_synthesizer = get_report_synthesizer(self.model_name)
+        self.progressive_report_synthesizer = get_progressive_report_synthesizer(self.model_name)
+
+        # Double-check that the synthesizers are using the correct model
+        logger.info(f"Report synthesizer model: {self.report_synthesizer.model_name}")
+        logger.info(f"Progressive report synthesizer model: {self.progressive_report_synthesizer.model_name}")
+
         # Prepare documents for report
         selected_chunks = await self.prepare_documents_for_report(
             search_results,
@@ -277,7 +309,10 @@ class ReportGenerator:
         # If no sub-questions or structured_query is None, use standard synthesizers
         elif self.detail_level.lower() == "comprehensive":
             # Use progressive report synthesizer for comprehensive detail level
-            logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level")
+            logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level with model {self.model_name}")
+            # Verify that the report synthesizer is using the correct model
+            if hasattr(self.progressive_report_synthesizer, 'model_name'):
+                logger.info(f"Progressive synthesizer model: {self.progressive_report_synthesizer.model_name}")
             report = await self.progressive_report_synthesizer.synthesize_report(
                 selected_chunks,
                 query,
@@ -286,7 +321,10 @@ class ReportGenerator:
             )
         else:
             # Use standard report synthesizer for other detail levels
-            logger.info(f"Using standard report synthesizer for {self.detail_level} detail level")
+            logger.info(f"Using standard report synthesizer for {self.detail_level} detail level with model {self.model_name}")
+            # Verify that the report synthesizer is using the correct model
+            if hasattr(self.report_synthesizer, 'model_name'):
+                logger.info(f"Standard synthesizer model: {self.report_synthesizer.model_name}")
             report = await self.report_synthesizer.synthesize_report(
                 selected_chunks,
                 query,

View File

@@ -98,6 +98,10 @@ class ReportSynthesizer:
         """Set up the LLM provider based on the model configuration."""
         provider = self.model_config.get('provider', 'groq')
 
+        # Log detailed model information for debugging
+        logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
+        logger.info(f"Model configuration: {self.model_config}")
+
         try:
             # Get API key for the provider
             api_key = self.config.get_api_key(provider)
@@ -105,12 +109,15 @@ class ReportSynthesizer:
             # Set environment variable for the provider
             if provider.lower() == 'google' or provider.lower() == 'gemini':
                 os.environ["GEMINI_API_KEY"] = api_key
+                logger.info("Configured with GEMINI_API_KEY")
             elif provider.lower() == 'vertex_ai':
                 os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
+                logger.info("Configured with GOOGLE_APPLICATION_CREDENTIALS")
             else:
                 os.environ[f"{provider.upper()}_API_KEY"] = api_key
+                logger.info(f"Configured with {provider.upper()}_API_KEY")
 
-            logger.info(f"Report synthesizer initialized with model: {self.model_name} (provider: {provider})")
+            logger.info(f"Report synthesizer successfully initialized with model: {self.model_name} (provider: {provider})")
         except ValueError as e:
             logger.error(f"Error setting up LLM provider: {e}")
@@ -229,6 +236,18 @@ class ReportSynthesizer:
         # Get completion parameters
         params = self._get_completion_params()
 
+        # Double-check that we're using the correct model
+        if provider == 'gemini':
+            params['model'] = f"gemini/{self.model_name}"
+        elif provider == 'groq':
+            params['model'] = f"groq/{self.model_name}"
+
+        # Log the actual parameters being used for the LLM call
+        safe_params = params.copy()
+        if 'headers' in safe_params:
+            safe_params['headers'] = "Redacted for logging"
+        logger.info(f"LLM completion parameters: {safe_params}")
+
        try:
             # Generate completion
             if stream:
@@ -239,6 +258,7 @@ class ReportSynthesizer:
                 )
                 return response
             else:
+                logger.info(f"Sending request to {params.get('model', 'unknown model')} with {len(formatted_messages)} messages")
                 response = litellm.completion(
                     messages=formatted_messages,
                     **params
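
The forced "provider/model" prefix above matches LiteLLM's routing convention: the prefix on the model string selects the backend. A standalone sketch (model name illustrative; the matching API key must be present in the environment):

    import litellm

    # LiteLLM dispatches to the Groq backend because of the "groq/" prefix.
    response = litellm.completion(
        model="groq/llama-3.1-8b-instant",  # illustrative model name
        messages=[{"role": "user", "content": "Summarize the findings in two sentences."}],
    )
    print(response.choices[0].message.content)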
@@ -304,7 +324,7 @@ class ReportSynthesizer:
         extraction_prompt = self._get_extraction_prompt(detail_level, query_type)
 
         total_chunks = len(chunks)
-        logger.info(f"Starting to process {total_chunks} document chunks")
+        logger.info(f"Starting to process {total_chunks} document chunks with model: {self.model_name}")
 
         # Update progress tracking state
         self.total_chunks = total_chunks
@@ -708,7 +728,20 @@ def get_report_synthesizer(model_name: Optional[str] = None) -> ReportSynthesize
     global report_synthesizer
 
     if model_name and model_name != report_synthesizer.model_name:
-        report_synthesizer = ReportSynthesizer(model_name)
+        logger.info(f"Creating new report synthesizer with model: {model_name}")
+        try:
+            previous_model = report_synthesizer.model_name
+            report_synthesizer = ReportSynthesizer(model_name)
+            logger.info(f"Successfully changed model from {previous_model} to {model_name}")
+        except Exception as e:
+            logger.error(f"Error creating new report synthesizer with model {model_name}: {str(e)}")
+            # Fall back to the existing synthesizer
+            logger.info(f"Falling back to existing synthesizer with model {report_synthesizer.model_name}")
+    else:
+        if model_name:
+            logger.info(f"Using existing report synthesizer with model: {model_name} (already initialized)")
+        else:
+            logger.info(f"Using existing report synthesizer with default model: {report_synthesizer.model_name}")
 
     return report_synthesizer
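
get_report_synthesizer() is a module-level singleton factory: it rebuilds the shared instance only when the requested model differs, and now falls back to the existing instance if construction fails. Stripped to its skeleton (names here are generic placeholders, not from this codebase):

    import logging

    logger = logging.getLogger(__name__)

    class Synthesizer:
        def __init__(self, model_name: str = "default-model"):  # placeholder default
            self.model_name = model_name

    _instance = Synthesizer()

    def get_synthesizer(model_name=None):
        """Return the shared instance, rebuilding only when the model changes."""
        global _instance
        if model_name and model_name != _instance.model_name:
            try:
                _instance = Synthesizer(model_name)
            except Exception:
                logger.exception("Rebuild failed; keeping %s", _instance.model_name)
        return _instance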

View File

@@ -213,11 +213,17 @@ class GradioInterface:
         # Extract the actual model name from the description if selected
         if custom_model:
             # If the model is in the format "model_name (provider: model_display)"
+            original_custom_model = custom_model
             if "(" in custom_model:
                 custom_model = custom_model.split(" (")[0]
             model_name = custom_model.split('/')[-1]
             model_suffix = f"_{model_name}"
+
+            # Log the model selection for debugging
+            print(f"Selected model from UI: {original_custom_model}")
+            print(f"Extracted model name: {custom_model}")
+            print(f"Using model suffix: {model_suffix}")
 
         output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"
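
The suffix logic assumes dropdown entries follow the "model_name (provider: model_display)" convention noted in the comment above. Worked through on a hypothetical entry:

    custom_model = "openrouter/gemini-2.0-flash (openrouter: Gemini Flash)"  # hypothetical dropdown entry
    custom_model = custom_model.split(" (")[0]  # -> "openrouter/gemini-2.0-flash"
    model_name = custom_model.split('/')[-1]    # -> "gemini-2.0-flash"
    model_suffix = f"_{model_name}"             # -> "_gemini-2.0-flash", appended to the report filename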
@@ -234,7 +240,10 @@ class GradioInterface:
 
         # If custom model is provided, use it
         if custom_model:
-            config["model"] = custom_model
+            # Extract the actual model name from the display name format if needed
+            model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
+            config["model"] = model_name
+            print(f"Using custom model: {model_name}")
 
         # Ensure report generator is initialized
         if self.report_generator is None:
@@ -242,8 +251,33 @@ class GradioInterface:
             await initialize_report_generator()
             self.report_generator = get_report_generator()
 
-        # This will update the report synthesizer to use the custom model
+        # Debug: Print initial model configuration based on detail level
+        detail_config = self.detail_level_manager.get_detail_level_config(detail_level)
+        default_model = detail_config.get("model", "unknown")
+        print(f"Default model for {detail_level} detail level: {default_model}")
+
+        # First set the detail level, which will set the default model for this detail level
         self.report_generator.set_detail_level(detail_level)
+        print(f"After setting detail level, report generator model is: {self.report_generator.model_name}")
+
+        # Then explicitly override with custom model if provided
+        if custom_model:
+            # Extract the actual model name from the display name format
+            # The format is "model_name (provider: model_display)"
+            model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
+            print(f"Setting report generator to use custom model: {model_name}")
+
+            # Look for a set_model method in the report generator
+            if hasattr(self.report_generator, 'set_model'):
+                self.report_generator.set_model(model_name)
+                print(f"After setting custom model, report generator model is: {self.report_generator.model_name}")
+            else:
+                print("Warning: Report generator does not have set_model method. Using alternative approach.")
+                # Update the config with the model as a fallback
+                current_config = self.report_generator.get_detail_level_config()
+                if current_config:
+                    current_config["model"] = model_name
+                    print(f"Updated config model to: {model_name}")
 
         print(f"Generating report with detail level: {detail_level}")
         print(f"Detail level configuration: {config}")
@@ -269,6 +303,12 @@ class GradioInterface:
             self.search_executor.get_available_search_engines()
         )
 
+        # Set the number of results to fetch per engine early so it's available throughout the function
+        num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
+
+        # Initialize sub_question_results as an empty dict in case there are no sub-questions
+        sub_question_results = {}
+
         # Check if the query was decomposed into sub-questions
         has_sub_questions = 'sub_questions' in structured_query and structured_query['sub_questions']
         if has_sub_questions:
@@ -295,10 +335,6 @@ class GradioInterface:
                 )
                 progress(0.2, desc="Completed sub-question searches")
 
-        # Execute the search with the structured query
-        # Use initial_results_per_engine if available, otherwise fall back to num_results
-        num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
-
         # Execute main search
         progress(0.3, desc="Executing main search...")
         search_results_dict = self.search_executor.execute_search(
@@ -413,6 +449,11 @@ class GradioInterface:
             if chunk_title:
                 status_message += f" ({chunk_title})"
 
+            # Add model information to status message
+            if hasattr(self.report_generator, 'model_name') and self.report_generator.model_name:
+                model_display = self.report_generator.model_name.split('/')[-1]  # Extract model name without provider
+                status_message += f" (Using model: {model_display})"
+
             # Update the progress status directly
             return status_message
@@ -435,6 +476,9 @@ class GradioInterface:
         if len(search_results) == 0:
             print("WARNING: No search results found. Report generation may fail.")
 
+        # Log the current model being used by the report generator
+        print(f"Report generator is using model: {self.report_generator.model_name}")
+
         # Update progress status based on detail level
         if detail_level.lower() == "comprehensive":
             self.progress_status = "Generating progressive report..."
@@ -451,6 +495,10 @@ class GradioInterface:
         else:
             print("Using auto-detection for query type")
 
+        # Ensure structured_query is defined
+        if not locals().get('structured_query'):
+            structured_query = None
+
         report = await self.report_generator.generate_report(
             search_results=search_results,
             query=query,
@@ -459,7 +507,7 @@ class GradioInterface:
             overlap_size=config["overlap_size"],
             detail_level=detail_level,
             query_type=actual_query_type,
-            structured_query=structured_query if 'sub_questions' in structured_query else None
+            structured_query=structured_query if structured_query and 'sub_questions' in structured_query else None
         )
 
         # Final progress update
@@ -747,6 +795,9 @@ class GradioInterface:
         # Set up progress tracking
         progress_data = gr.Progress(track_tqdm=True)
 
+        # Debug the model selection
+        print(f"Model selected from UI dropdown: {model_name}")
+
         # Call the original generate_report method
         result = await self.generate_report(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results)