Fix model selection in report generation to ensure UI-selected model is properly used throughout the entire report generation pipeline
This commit is contained in:
parent
76748f504e
commit
d76cd9d79b
|
@ -50,6 +50,27 @@ class ReportGenerator:
|
||||||
await initialize_database()
|
await initialize_database()
|
||||||
logger.info("Report generator initialized")
|
logger.info("Report generator initialized")
|
||||||
|
|
||||||
|
def set_model(self, model_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Set the model to use for report generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model to use
|
||||||
|
"""
|
||||||
|
if model_name and model_name != self.model_name:
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
# Update synthesizers to use this model
|
||||||
|
self.report_synthesizer = get_report_synthesizer(model_name)
|
||||||
|
self.progressive_report_synthesizer = get_progressive_report_synthesizer(model_name)
|
||||||
|
|
||||||
|
# Also update the model in the current detail level config
|
||||||
|
config = self.get_detail_level_config()
|
||||||
|
if config:
|
||||||
|
config["model"] = model_name
|
||||||
|
|
||||||
|
logger.info(f"Set model to {model_name}")
|
||||||
|
|
||||||
def set_detail_level(self, detail_level: str) -> None:
|
def set_detail_level(self, detail_level: str) -> None:
|
||||||
"""
|
"""
|
||||||
Set the detail level for report generation.
|
Set the detail level for report generation.
|
||||||
|
@ -62,14 +83,14 @@ class ReportGenerator:
|
||||||
config = self.detail_level_manager.get_detail_level_config(detail_level)
|
config = self.detail_level_manager.get_detail_level_config(detail_level)
|
||||||
self.detail_level = detail_level
|
self.detail_level = detail_level
|
||||||
|
|
||||||
# Update model if needed
|
# Update model if needed and no custom model is set
|
||||||
model = config.get("model")
|
model = config.get("model")
|
||||||
if model and model != self.model_name:
|
if model and model != self.model_name and not self.model_name:
|
||||||
self.model_name = model
|
self.model_name = model
|
||||||
self.report_synthesizer = get_report_synthesizer(model)
|
self.report_synthesizer = get_report_synthesizer(model)
|
||||||
self.progressive_report_synthesizer = get_progressive_report_synthesizer(model)
|
self.progressive_report_synthesizer = get_progressive_report_synthesizer(model)
|
||||||
|
|
||||||
logger.info(f"Detail level set to {detail_level} with model {model}")
|
logger.info(f"Detail level set to {detail_level} with model {self.model_name or model}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error setting detail level: {e}")
|
logger.error(f"Error setting detail level: {e}")
|
||||||
raise
|
raise
|
||||||
|
@ -236,6 +257,17 @@ class ReportGenerator:
|
||||||
if detail_level:
|
if detail_level:
|
||||||
self.set_detail_level(detail_level)
|
self.set_detail_level(detail_level)
|
||||||
|
|
||||||
|
# Log information about current model and reinitialize synthesizers to ensure they use the correct model
|
||||||
|
logger.info(f"Starting report generation with model: {self.model_name or 'default model'}")
|
||||||
|
|
||||||
|
# Reinitialize the synthesizers with the current model to ensure they use the correct model
|
||||||
|
self.report_synthesizer = get_report_synthesizer(self.model_name)
|
||||||
|
self.progressive_report_synthesizer = get_progressive_report_synthesizer(self.model_name)
|
||||||
|
|
||||||
|
# Double-check that the synthesizers are using the correct model
|
||||||
|
logger.info(f"Report synthesizer model: {self.report_synthesizer.model_name}")
|
||||||
|
logger.info(f"Progressive report synthesizer model: {self.progressive_report_synthesizer.model_name}")
|
||||||
|
|
||||||
# Prepare documents for report
|
# Prepare documents for report
|
||||||
selected_chunks = await self.prepare_documents_for_report(
|
selected_chunks = await self.prepare_documents_for_report(
|
||||||
search_results,
|
search_results,
|
||||||
|
@ -277,7 +309,10 @@ class ReportGenerator:
|
||||||
# If no sub-questions or structured_query is None, use standard synthesizers
|
# If no sub-questions or structured_query is None, use standard synthesizers
|
||||||
elif self.detail_level.lower() == "comprehensive":
|
elif self.detail_level.lower() == "comprehensive":
|
||||||
# Use progressive report synthesizer for comprehensive detail level
|
# Use progressive report synthesizer for comprehensive detail level
|
||||||
logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level")
|
logger.info(f"Using progressive report synthesizer for {self.detail_level} detail level with model {self.model_name}")
|
||||||
|
# Verify that the report synthesizer is using the correct model
|
||||||
|
if hasattr(self.progressive_report_synthesizer, 'model_name'):
|
||||||
|
logger.info(f"Progressive synthesizer model: {self.progressive_report_synthesizer.model_name}")
|
||||||
report = await self.progressive_report_synthesizer.synthesize_report(
|
report = await self.progressive_report_synthesizer.synthesize_report(
|
||||||
selected_chunks,
|
selected_chunks,
|
||||||
query,
|
query,
|
||||||
|
@ -286,7 +321,10 @@ class ReportGenerator:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use standard report synthesizer for other detail levels
|
# Use standard report synthesizer for other detail levels
|
||||||
logger.info(f"Using standard report synthesizer for {self.detail_level} detail level")
|
logger.info(f"Using standard report synthesizer for {self.detail_level} detail level with model {self.model_name}")
|
||||||
|
# Verify that the report synthesizer is using the correct model
|
||||||
|
if hasattr(self.report_synthesizer, 'model_name'):
|
||||||
|
logger.info(f"Standard synthesizer model: {self.report_synthesizer.model_name}")
|
||||||
report = await self.report_synthesizer.synthesize_report(
|
report = await self.report_synthesizer.synthesize_report(
|
||||||
selected_chunks,
|
selected_chunks,
|
||||||
query,
|
query,
|
||||||
|
|
|
@ -98,6 +98,10 @@ class ReportSynthesizer:
|
||||||
"""Set up the LLM provider based on the model configuration."""
|
"""Set up the LLM provider based on the model configuration."""
|
||||||
provider = self.model_config.get('provider', 'groq')
|
provider = self.model_config.get('provider', 'groq')
|
||||||
|
|
||||||
|
# Log detailed model information for debugging
|
||||||
|
logger.info(f"Setting up report synthesizer with model: {self.model_name} (provider: {provider})")
|
||||||
|
logger.info(f"Model configuration: {self.model_config}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get API key for the provider
|
# Get API key for the provider
|
||||||
api_key = self.config.get_api_key(provider)
|
api_key = self.config.get_api_key(provider)
|
||||||
|
@ -105,12 +109,15 @@ class ReportSynthesizer:
|
||||||
# Set environment variable for the provider
|
# Set environment variable for the provider
|
||||||
if provider.lower() == 'google' or provider.lower() == 'gemini':
|
if provider.lower() == 'google' or provider.lower() == 'gemini':
|
||||||
os.environ["GEMINI_API_KEY"] = api_key
|
os.environ["GEMINI_API_KEY"] = api_key
|
||||||
|
logger.info("Configured with GEMINI_API_KEY")
|
||||||
elif provider.lower() == 'vertex_ai':
|
elif provider.lower() == 'vertex_ai':
|
||||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
|
||||||
|
logger.info("Configured with GOOGLE_APPLICATION_CREDENTIALS")
|
||||||
else:
|
else:
|
||||||
os.environ[f"{provider.upper()}_API_KEY"] = api_key
|
os.environ[f"{provider.upper()}_API_KEY"] = api_key
|
||||||
|
logger.info(f"Configured with {provider.upper()}_API_KEY")
|
||||||
|
|
||||||
logger.info(f"Report synthesizer initialized with model: {self.model_name} (provider: {provider})")
|
logger.info(f"Report synthesizer successfully initialized with model: {self.model_name} (provider: {provider})")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error setting up LLM provider: {e}")
|
logger.error(f"Error setting up LLM provider: {e}")
|
||||||
|
|
||||||
|
@ -229,6 +236,18 @@ class ReportSynthesizer:
|
||||||
# Get completion parameters
|
# Get completion parameters
|
||||||
params = self._get_completion_params()
|
params = self._get_completion_params()
|
||||||
|
|
||||||
|
# Double-check that we're using the correct model
|
||||||
|
if provider == 'gemini':
|
||||||
|
params['model'] = f"gemini/{self.model_name}"
|
||||||
|
elif provider == 'groq':
|
||||||
|
params['model'] = f"groq/{self.model_name}"
|
||||||
|
|
||||||
|
# Log the actual parameters being used for the LLM call
|
||||||
|
safe_params = params.copy()
|
||||||
|
if 'headers' in safe_params:
|
||||||
|
safe_params['headers'] = "Redacted for logging"
|
||||||
|
logger.info(f"LLM completion parameters: {safe_params}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate completion
|
# Generate completion
|
||||||
if stream:
|
if stream:
|
||||||
|
@ -239,6 +258,7 @@ class ReportSynthesizer:
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
|
logger.info(f"Sending request to {params.get('model', 'unknown model')} with {len(formatted_messages)} messages")
|
||||||
response = litellm.completion(
|
response = litellm.completion(
|
||||||
messages=formatted_messages,
|
messages=formatted_messages,
|
||||||
**params
|
**params
|
||||||
|
@ -304,7 +324,7 @@ class ReportSynthesizer:
|
||||||
extraction_prompt = self._get_extraction_prompt(detail_level, query_type)
|
extraction_prompt = self._get_extraction_prompt(detail_level, query_type)
|
||||||
|
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
logger.info(f"Starting to process {total_chunks} document chunks")
|
logger.info(f"Starting to process {total_chunks} document chunks with model: {self.model_name}")
|
||||||
|
|
||||||
# Update progress tracking state
|
# Update progress tracking state
|
||||||
self.total_chunks = total_chunks
|
self.total_chunks = total_chunks
|
||||||
|
@ -708,7 +728,20 @@ def get_report_synthesizer(model_name: Optional[str] = None) -> ReportSynthesize
|
||||||
global report_synthesizer
|
global report_synthesizer
|
||||||
|
|
||||||
if model_name and model_name != report_synthesizer.model_name:
|
if model_name and model_name != report_synthesizer.model_name:
|
||||||
report_synthesizer = ReportSynthesizer(model_name)
|
logger.info(f"Creating new report synthesizer with model: {model_name}")
|
||||||
|
try:
|
||||||
|
previous_model = report_synthesizer.model_name
|
||||||
|
report_synthesizer = ReportSynthesizer(model_name)
|
||||||
|
logger.info(f"Successfully changed model from {previous_model} to {model_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating new report synthesizer with model {model_name}: {str(e)}")
|
||||||
|
# Fall back to the existing synthesizer
|
||||||
|
logger.info(f"Falling back to existing synthesizer with model {report_synthesizer.model_name}")
|
||||||
|
else:
|
||||||
|
if model_name:
|
||||||
|
logger.info(f"Using existing report synthesizer with model: {model_name} (already initialized)")
|
||||||
|
else:
|
||||||
|
logger.info(f"Using existing report synthesizer with default model: {report_synthesizer.model_name}")
|
||||||
|
|
||||||
return report_synthesizer
|
return report_synthesizer
|
||||||
|
|
||||||
|
|
|
@ -213,11 +213,17 @@ class GradioInterface:
|
||||||
# Extract the actual model name from the description if selected
|
# Extract the actual model name from the description if selected
|
||||||
if custom_model:
|
if custom_model:
|
||||||
# If the model is in the format "model_name (provider: model_display)"
|
# If the model is in the format "model_name (provider: model_display)"
|
||||||
|
original_custom_model = custom_model
|
||||||
if "(" in custom_model:
|
if "(" in custom_model:
|
||||||
custom_model = custom_model.split(" (")[0]
|
custom_model = custom_model.split(" (")[0]
|
||||||
|
|
||||||
model_name = custom_model.split('/')[-1]
|
model_name = custom_model.split('/')[-1]
|
||||||
model_suffix = f"_{model_name}"
|
model_suffix = f"_{model_name}"
|
||||||
|
|
||||||
|
# Log the model selection for debugging
|
||||||
|
print(f"Selected model from UI: {original_custom_model}")
|
||||||
|
print(f"Extracted model name: {custom_model}")
|
||||||
|
print(f"Using model suffix: {model_suffix}")
|
||||||
|
|
||||||
output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"
|
output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"
|
||||||
|
|
||||||
|
@ -234,7 +240,10 @@ class GradioInterface:
|
||||||
|
|
||||||
# If custom model is provided, use it
|
# If custom model is provided, use it
|
||||||
if custom_model:
|
if custom_model:
|
||||||
config["model"] = custom_model
|
# Extract the actual model name from the display name format if needed
|
||||||
|
model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
|
||||||
|
config["model"] = model_name
|
||||||
|
print(f"Using custom model: {model_name}")
|
||||||
|
|
||||||
# Ensure report generator is initialized
|
# Ensure report generator is initialized
|
||||||
if self.report_generator is None:
|
if self.report_generator is None:
|
||||||
|
@ -242,8 +251,33 @@ class GradioInterface:
|
||||||
await initialize_report_generator()
|
await initialize_report_generator()
|
||||||
self.report_generator = get_report_generator()
|
self.report_generator = get_report_generator()
|
||||||
|
|
||||||
# This will update the report synthesizer to use the custom model
|
# Debug: Print initial model configuration based on detail level
|
||||||
|
detail_config = self.detail_level_manager.get_detail_level_config(detail_level)
|
||||||
|
default_model = detail_config.get("model", "unknown")
|
||||||
|
print(f"Default model for {detail_level} detail level: {default_model}")
|
||||||
|
|
||||||
|
# First set the detail level, which will set the default model for this detail level
|
||||||
self.report_generator.set_detail_level(detail_level)
|
self.report_generator.set_detail_level(detail_level)
|
||||||
|
print(f"After setting detail level, report generator model is: {self.report_generator.model_name}")
|
||||||
|
|
||||||
|
# Then explicitly override with custom model if provided
|
||||||
|
if custom_model:
|
||||||
|
# Extract the actual model name from the display name format
|
||||||
|
# The format is "model_name (provider: model_display)"
|
||||||
|
model_name = custom_model.split(" (")[0] if " (" in custom_model else custom_model
|
||||||
|
print(f"Setting report generator to use custom model: {model_name}")
|
||||||
|
|
||||||
|
# Look for a set_model method in the report generator
|
||||||
|
if hasattr(self.report_generator, 'set_model'):
|
||||||
|
self.report_generator.set_model(model_name)
|
||||||
|
print(f"After setting custom model, report generator model is: {self.report_generator.model_name}")
|
||||||
|
else:
|
||||||
|
print("Warning: Report generator does not have set_model method. Using alternative approach.")
|
||||||
|
# Update the config with the model as a fallback
|
||||||
|
current_config = self.report_generator.get_detail_level_config()
|
||||||
|
if current_config:
|
||||||
|
current_config["model"] = model_name
|
||||||
|
print(f"Updated config model to: {model_name}")
|
||||||
|
|
||||||
print(f"Generating report with detail level: {detail_level}")
|
print(f"Generating report with detail level: {detail_level}")
|
||||||
print(f"Detail level configuration: {config}")
|
print(f"Detail level configuration: {config}")
|
||||||
|
@ -269,6 +303,12 @@ class GradioInterface:
|
||||||
self.search_executor.get_available_search_engines()
|
self.search_executor.get_available_search_engines()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set the number of results to fetch per engine early so it's available throughout the function
|
||||||
|
num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
|
||||||
|
|
||||||
|
# Initialize sub_question_results as an empty dict in case there are no sub-questions
|
||||||
|
sub_question_results = {}
|
||||||
|
|
||||||
# Check if the query was decomposed into sub-questions
|
# Check if the query was decomposed into sub-questions
|
||||||
has_sub_questions = 'sub_questions' in structured_query and structured_query['sub_questions']
|
has_sub_questions = 'sub_questions' in structured_query and structured_query['sub_questions']
|
||||||
if has_sub_questions:
|
if has_sub_questions:
|
||||||
|
@ -295,10 +335,6 @@ class GradioInterface:
|
||||||
)
|
)
|
||||||
progress(0.2, desc="Completed sub-question searches")
|
progress(0.2, desc="Completed sub-question searches")
|
||||||
|
|
||||||
# Execute the search with the structured query
|
|
||||||
# Use initial_results_per_engine if available, otherwise fall back to num_results
|
|
||||||
num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
|
|
||||||
|
|
||||||
# Execute main search
|
# Execute main search
|
||||||
progress(0.3, desc="Executing main search...")
|
progress(0.3, desc="Executing main search...")
|
||||||
search_results_dict = self.search_executor.execute_search(
|
search_results_dict = self.search_executor.execute_search(
|
||||||
|
@ -413,6 +449,11 @@ class GradioInterface:
|
||||||
if chunk_title:
|
if chunk_title:
|
||||||
status_message += f" ({chunk_title})"
|
status_message += f" ({chunk_title})"
|
||||||
|
|
||||||
|
# Add model information to status message
|
||||||
|
if hasattr(self.report_generator, 'model_name') and self.report_generator.model_name:
|
||||||
|
model_display = self.report_generator.model_name.split('/')[-1] # Extract model name without provider
|
||||||
|
status_message += f" (Using model: {model_display})"
|
||||||
|
|
||||||
# Update the progress status directly
|
# Update the progress status directly
|
||||||
return status_message
|
return status_message
|
||||||
|
|
||||||
|
@ -435,6 +476,9 @@ class GradioInterface:
|
||||||
if len(search_results) == 0:
|
if len(search_results) == 0:
|
||||||
print("WARNING: No search results found. Report generation may fail.")
|
print("WARNING: No search results found. Report generation may fail.")
|
||||||
|
|
||||||
|
# Log the current model being used by the report generator
|
||||||
|
print(f"Report generator is using model: {self.report_generator.model_name}")
|
||||||
|
|
||||||
# Update progress status based on detail level
|
# Update progress status based on detail level
|
||||||
if detail_level.lower() == "comprehensive":
|
if detail_level.lower() == "comprehensive":
|
||||||
self.progress_status = "Generating progressive report..."
|
self.progress_status = "Generating progressive report..."
|
||||||
|
@ -451,6 +495,10 @@ class GradioInterface:
|
||||||
else:
|
else:
|
||||||
print("Using auto-detection for query type")
|
print("Using auto-detection for query type")
|
||||||
|
|
||||||
|
# Ensure structured_query is defined
|
||||||
|
if not locals().get('structured_query'):
|
||||||
|
structured_query = None
|
||||||
|
|
||||||
report = await self.report_generator.generate_report(
|
report = await self.report_generator.generate_report(
|
||||||
search_results=search_results,
|
search_results=search_results,
|
||||||
query=query,
|
query=query,
|
||||||
|
@ -459,7 +507,7 @@ class GradioInterface:
|
||||||
overlap_size=config["overlap_size"],
|
overlap_size=config["overlap_size"],
|
||||||
detail_level=detail_level,
|
detail_level=detail_level,
|
||||||
query_type=actual_query_type,
|
query_type=actual_query_type,
|
||||||
structured_query=structured_query if 'sub_questions' in structured_query else None
|
structured_query=structured_query if structured_query and 'sub_questions' in structured_query else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Final progress update
|
# Final progress update
|
||||||
|
@ -747,6 +795,9 @@ class GradioInterface:
|
||||||
# Set up progress tracking
|
# Set up progress tracking
|
||||||
progress_data = gr.Progress(track_tqdm=True)
|
progress_data = gr.Progress(track_tqdm=True)
|
||||||
|
|
||||||
|
# Debug the model selection
|
||||||
|
print(f"Model selected from UI dropdown: {model_name}")
|
||||||
|
|
||||||
# Call the original generate_report method
|
# Call the original generate_report method
|
||||||
result = await self.generate_report(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results)
|
result = await self.generate_report(query, detail_level, query_type, model_name, rerank, token_budget, initial_results, final_results)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue