Add progress tracking to report generation UI

This commit is contained in:
Steve White 2025-03-12 11:20:40 -05:00
parent 71ad21a1e7
commit 21f75c0d25
4 changed files with 104 additions and 4 deletions

View File

@ -187,6 +187,20 @@ class ReportGenerator:
return selected_chunks
def set_progress_callback(self, callback):
    """
    Register a progress callback on every synthesizer owned by this generator.

    Args:
        callback: Function that takes (current_progress, total, current_report) as arguments
    """
    # Propagate the callback to each synthesizer that supports progress reporting.
    for synthesizer in (self.report_synthesizer, self.progressive_report_synthesizer):
        if hasattr(synthesizer, 'set_progress_callback'):
            synthesizer.set_progress_callback(callback)
async def generate_report(self,
    search_results: List[Dict[str, Any]],
    query: str,

View File

@ -57,6 +57,26 @@ class ReportSynthesizer:
# Flag to process <thinking> tags in model output # Flag to process <thinking> tags in model output
self.process_thinking_tags = False self.process_thinking_tags = False
# Progress tracking
self.progress_callback = None
self.total_chunks = 0
self.processed_chunk_count = 0
def set_progress_callback(self, callback):
    """
    Store the function used to report synthesis progress.

    Args:
        callback: Function that takes (current_progress, total, current_report) as arguments
    """
    self.progress_callback = callback
def _report_progress(self, current_report=None):
"""Report progress through the callback if set."""
if self.progress_callback and self.total_chunks > 0:
progress = min(self.processed_chunk_count / self.total_chunks, 1.0)
self.progress_callback(progress, self.total_chunks, current_report)
def _setup_provider(self) -> None: def _setup_provider(self) -> None:
"""Set up the LLM provider based on the model configuration.""" """Set up the LLM provider based on the model configuration."""
@ -289,6 +309,10 @@ class ReportSynthesizer:
processed_chunk['extracted_info'] = extracted_info processed_chunk['extracted_info'] = extracted_info
batch_results.append(processed_chunk) batch_results.append(processed_chunk)
# Update progress
self.processed_chunk_count += 1
self._report_progress()
logger.info(f"Completed chunk {chunk_index}/{total_chunks} ({chunk_index/total_chunks*100:.1f}% complete)")
except Exception as e: except Exception as e:
logger.error(f"Error processing chunk {chunk_index}/{total_chunks}: {str(e)}") logger.error(f"Error processing chunk {chunk_index}/{total_chunks}: {str(e)}")
@ -296,6 +320,10 @@ class ReportSynthesizer:
processed_chunk = chunk.copy() processed_chunk = chunk.copy()
processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}" processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
batch_results.append(processed_chunk) batch_results.append(processed_chunk)
# Update progress even for failed chunks
self.processed_chunk_count += 1
self._report_progress()
processed_chunks.extend(batch_results) processed_chunks.extend(batch_results)
@ -510,6 +538,10 @@ class ReportSynthesizer:
logger.warning("No document chunks provided for report synthesis.") logger.warning("No document chunks provided for report synthesis.")
return "No information found for the given query." return "No information found for the given query."
# Reset progress tracking
self.total_chunks = len(chunks)
self.processed_chunk_count = 0
# Verify that a template exists for the given query type and detail level # Verify that a template exists for the given query type and detail level
template = self._get_template_from_strings(query_type, detail_level) template = self._get_template_from_strings(query_type, detail_level)
if not template: if not template:
@ -545,6 +577,9 @@ class ReportSynthesizer:
# Recalculate estimated tokens # Recalculate estimated tokens
total_tokens = sum(len(chunk.get('content', '').split()) * 1.3 for chunk in chunks) total_tokens = sum(len(chunk.get('content', '').split()) * 1.3 for chunk in chunks)
logger.info(f"Reduced to {len(chunks)} chunks with estimated {total_tokens} tokens") logger.info(f"Reduced to {len(chunks)} chunks with estimated {total_tokens} tokens")
# Update total chunks for progress tracking
self.total_chunks = len(chunks)
logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'") logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")
@ -578,6 +613,10 @@ class ReportSynthesizer:
logger.info(f"Starting reduce phase to synthesize report from {len(processed_chunks)} processed chunks") logger.info(f"Starting reduce phase to synthesize report from {len(processed_chunks)} processed chunks")
# Update progress status for reduce phase
if self.progress_callback:
self.progress_callback(0.9, self.total_chunks, "Synthesizing final report...")
# Reduce phase: Synthesize processed chunks into a coherent report # Reduce phase: Synthesize processed chunks into a coherent report
report = await self.reduce_processed_chunks(processed_chunks, query, query_type, detail_level) report = await self.reduce_processed_chunks(processed_chunks, query, query_type, detail_level)
@ -586,6 +625,10 @@ class ReportSynthesizer:
logger.info("Processing thinking tags in report") logger.info("Processing thinking tags in report")
report = self._process_thinking_tags(report) report = self._process_thinking_tags(report)
# Final progress update
if self.progress_callback:
self.progress_callback(1.0, self.total_chunks, report)
return report return report

View File

@ -35,8 +35,15 @@ def main():
args = parse_args() args = parse_args()
print("Starting Intelligent Research System UI...") print("Starting Intelligent Research System UI...")
# Create interface and initialize async components
import asyncio
interface = GradioInterface() interface = GradioInterface()
# Run the async initialization in the event loop
loop = asyncio.get_event_loop()
loop.run_until_complete(interface.async_init())
# Launch with the specified arguments # Launch with the specified arguments
interface.launch( interface.launch(
share=args.share, share=args.share,

View File

@ -185,7 +185,7 @@ class GradioInterface:
return markdown return markdown
async def generate_report(self, query, detail_level="standard", custom_model=None, async def generate_report(self, query, detail_level="standard", custom_model=None,
results_file=None, process_thinking_tags=False): results_file=None, process_thinking_tags=False, progress=gr.Progress()):
""" """
Generate a report for the given query. Generate a report for the given query.
@ -195,6 +195,7 @@ class GradioInterface:
custom_model: Custom model to use for report generation custom_model: Custom model to use for report generation
results_file: Path to a file containing search results results_file: Path to a file containing search results
process_thinking_tags: Whether to process thinking tags in the model output process_thinking_tags: Whether to process thinking tags in the model output
progress: Gradio progress indicator
Returns: Returns:
Path to the generated report Path to the generated report
@ -221,8 +222,15 @@ class GradioInterface:
# If custom model is provided, use it # If custom model is provided, use it
if custom_model: if custom_model:
config["model"] = custom_model config["model"] = custom_model
# This will update the report synthesizer to use the custom model
self.report_generator.set_detail_level(detail_level) # Ensure report generator is initialized
if self.report_generator is None:
print("Initializing report generator...")
await initialize_report_generator()
self.report_generator = get_report_generator()
# This will update the report synthesizer to use the custom model
self.report_generator.set_detail_level(detail_level)
print(f"Generating report with detail level: {detail_level}") print(f"Generating report with detail level: {detail_level}")
print(f"Detail level configuration: {config}") print(f"Detail level configuration: {config}")
@ -323,11 +331,36 @@ class GradioInterface:
top_n=config["num_results"] top_n=config["num_results"]
) )
# Set up progress tracking
self.progress_status = "Preparing documents..."
self.progress_value = 0
self.progress_total = 1 # Will be updated when we know the total chunks
# Define progress callback function
def progress_callback(current_progress, total_chunks, current_report):
self.progress_value = current_progress
self.progress_total = total_chunks
# Update the progress bar
progress(current_progress)
# Set the progress callback for the report generator
if hasattr(self.report_generator, 'set_progress_callback'):
self.report_generator.set_progress_callback(progress_callback)
# Generate the report # Generate the report
print(f"Generating report with {len(search_results)} search results") print(f"Generating report with {len(search_results)} search results")
if len(search_results) == 0: if len(search_results) == 0:
print("WARNING: No search results found. Report generation may fail.") print("WARNING: No search results found. Report generation may fail.")
# Update progress status based on detail level
if detail_level.lower() == "comprehensive":
self.progress_status = "Generating progressive report..."
else:
self.progress_status = "Processing document chunks..."
# Initial progress update
progress(0)
report = await self.report_generator.generate_report( report = await self.report_generator.generate_report(
search_results=search_results, search_results=search_results,
query=query, query=query,
@ -337,6 +370,9 @@ class GradioInterface:
detail_level=detail_level detail_level=detail_level
) )
# Final progress update
progress(1.0)
# Process thinking tags if requested # Process thinking tags if requested
if process_thinking_tags: if process_thinking_tags:
report = self._process_thinking_tags(report) report = self._process_thinking_tags(report)