Add progress tracking to report generation UI
This commit is contained in:
parent
71ad21a1e7
commit
21f75c0d25
|
@ -187,6 +187,20 @@ class ReportGenerator:
|
|||
|
||||
return selected_chunks
|
||||
|
||||
def set_progress_callback(self, callback):
    """
    Register a progress callback on both underlying synthesizers.

    Args:
        callback: Function that takes (current_progress, total, current_report) as arguments
    """
    # Fan the callback out to each synthesizer; older synthesizer
    # implementations may not expose set_progress_callback, so guard each one.
    for synthesizer in (self.report_synthesizer, self.progressive_report_synthesizer):
        if hasattr(synthesizer, 'set_progress_callback'):
            synthesizer.set_progress_callback(callback)
|
||||
|
||||
async def generate_report(self,
|
||||
search_results: List[Dict[str, Any]],
|
||||
query: str,
|
||||
|
|
|
@ -58,6 +58,26 @@ class ReportSynthesizer:
|
|||
# Flag to process <thinking> tags in model output
|
||||
self.process_thinking_tags = False
|
||||
|
||||
# Progress tracking
|
||||
self.progress_callback = None
|
||||
self.total_chunks = 0
|
||||
self.processed_chunk_count = 0
|
||||
|
||||
def set_progress_callback(self, callback):
    """
    Store a callback used to report chunk-processing progress.

    Args:
        callback: Function that takes (current_progress, total, current_report) as arguments
    """
    self.progress_callback = callback
|
||||
|
||||
def _report_progress(self, current_report=None):
|
||||
"""Report progress through the callback if set."""
|
||||
if self.progress_callback and self.total_chunks > 0:
|
||||
progress = min(self.processed_chunk_count / self.total_chunks, 1.0)
|
||||
self.progress_callback(progress, self.total_chunks, current_report)
|
||||
|
||||
def _setup_provider(self) -> None:
|
||||
"""Set up the LLM provider based on the model configuration."""
|
||||
provider = self.model_config.get('provider', 'groq')
|
||||
|
@ -289,6 +309,10 @@ class ReportSynthesizer:
|
|||
processed_chunk['extracted_info'] = extracted_info
|
||||
batch_results.append(processed_chunk)
|
||||
|
||||
# Update progress
|
||||
self.processed_chunk_count += 1
|
||||
self._report_progress()
|
||||
|
||||
logger.info(f"Completed chunk {chunk_index}/{total_chunks} ({chunk_index/total_chunks*100:.1f}% complete)")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing chunk {chunk_index}/{total_chunks}: {str(e)}")
|
||||
|
@ -297,6 +321,10 @@ class ReportSynthesizer:
|
|||
processed_chunk['extracted_info'] = f"Error extracting information: {str(e)}"
|
||||
batch_results.append(processed_chunk)
|
||||
|
||||
# Update progress even for failed chunks
|
||||
self.processed_chunk_count += 1
|
||||
self._report_progress()
|
||||
|
||||
processed_chunks.extend(batch_results)
|
||||
|
||||
# Add a small delay between batches to avoid rate limiting
|
||||
|
@ -510,6 +538,10 @@ class ReportSynthesizer:
|
|||
logger.warning("No document chunks provided for report synthesis.")
|
||||
return "No information found for the given query."
|
||||
|
||||
# Reset progress tracking
|
||||
self.total_chunks = len(chunks)
|
||||
self.processed_chunk_count = 0
|
||||
|
||||
# Verify that a template exists for the given query type and detail level
|
||||
template = self._get_template_from_strings(query_type, detail_level)
|
||||
if not template:
|
||||
|
@ -546,6 +578,9 @@ class ReportSynthesizer:
|
|||
total_tokens = sum(len(chunk.get('content', '').split()) * 1.3 for chunk in chunks)
|
||||
logger.info(f"Reduced to {len(chunks)} chunks with estimated {total_tokens} tokens")
|
||||
|
||||
# Update total chunks for progress tracking
|
||||
self.total_chunks = len(chunks)
|
||||
|
||||
logger.info(f"Starting map phase for {len(chunks)} document chunks with query type '{query_type}' and detail level '{detail_level}'")
|
||||
|
||||
# Process chunks in batches to avoid hitting payload limits
|
||||
|
@ -578,6 +613,10 @@ class ReportSynthesizer:
|
|||
|
||||
logger.info(f"Starting reduce phase to synthesize report from {len(processed_chunks)} processed chunks")
|
||||
|
||||
# Update progress status for reduce phase
|
||||
if self.progress_callback:
|
||||
self.progress_callback(0.9, self.total_chunks, "Synthesizing final report...")
|
||||
|
||||
# Reduce phase: Synthesize processed chunks into a coherent report
|
||||
report = await self.reduce_processed_chunks(processed_chunks, query, query_type, detail_level)
|
||||
|
||||
|
@ -586,6 +625,10 @@ class ReportSynthesizer:
|
|||
logger.info("Processing thinking tags in report")
|
||||
report = self._process_thinking_tags(report)
|
||||
|
||||
# Final progress update
|
||||
if self.progress_callback:
|
||||
self.progress_callback(1.0, self.total_chunks, report)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
|
|
|
@ -35,8 +35,15 @@ def main():
|
|||
args = parse_args()
|
||||
|
||||
print("Starting Intelligent Research System UI...")
|
||||
|
||||
# Create interface and initialize async components
|
||||
import asyncio
|
||||
interface = GradioInterface()
|
||||
|
||||
# Run the async initialization in the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(interface.async_init())
|
||||
|
||||
# Launch with the specified arguments
|
||||
interface.launch(
|
||||
share=args.share,
|
||||
|
|
|
@ -185,7 +185,7 @@ class GradioInterface:
|
|||
return markdown
|
||||
|
||||
async def generate_report(self, query, detail_level="standard", custom_model=None,
|
||||
results_file=None, process_thinking_tags=False):
|
||||
results_file=None, process_thinking_tags=False, progress=gr.Progress()):
|
||||
"""
|
||||
Generate a report for the given query.
|
||||
|
||||
|
@ -195,6 +195,7 @@ class GradioInterface:
|
|||
custom_model: Custom model to use for report generation
|
||||
results_file: Path to a file containing search results
|
||||
process_thinking_tags: Whether to process thinking tags in the model output
|
||||
progress: Gradio progress indicator
|
||||
|
||||
Returns:
|
||||
Path to the generated report
|
||||
|
@ -221,8 +222,15 @@ class GradioInterface:
|
|||
# If custom model is provided, use it
|
||||
if custom_model:
|
||||
config["model"] = custom_model
|
||||
# This will update the report synthesizer to use the custom model
|
||||
self.report_generator.set_detail_level(detail_level)
|
||||
|
||||
# Ensure report generator is initialized
|
||||
if self.report_generator is None:
|
||||
print("Initializing report generator...")
|
||||
await initialize_report_generator()
|
||||
self.report_generator = get_report_generator()
|
||||
|
||||
# This will update the report synthesizer to use the custom model
|
||||
self.report_generator.set_detail_level(detail_level)
|
||||
|
||||
print(f"Generating report with detail level: {detail_level}")
|
||||
print(f"Detail level configuration: {config}")
|
||||
|
@ -323,11 +331,36 @@ class GradioInterface:
|
|||
top_n=config["num_results"]
|
||||
)
|
||||
|
||||
# Set up progress tracking
|
||||
self.progress_status = "Preparing documents..."
|
||||
self.progress_value = 0
|
||||
self.progress_total = 1 # Will be updated when we know the total chunks
|
||||
|
||||
# Define progress callback function
|
||||
def progress_callback(current_progress, total_chunks, current_report):
|
||||
self.progress_value = current_progress
|
||||
self.progress_total = total_chunks
|
||||
# Update the progress bar
|
||||
progress(current_progress)
|
||||
|
||||
# Set the progress callback for the report generator
|
||||
if hasattr(self.report_generator, 'set_progress_callback'):
|
||||
self.report_generator.set_progress_callback(progress_callback)
|
||||
|
||||
# Generate the report
|
||||
print(f"Generating report with {len(search_results)} search results")
|
||||
if len(search_results) == 0:
|
||||
print("WARNING: No search results found. Report generation may fail.")
|
||||
|
||||
# Update progress status based on detail level
|
||||
if detail_level.lower() == "comprehensive":
|
||||
self.progress_status = "Generating progressive report..."
|
||||
else:
|
||||
self.progress_status = "Processing document chunks..."
|
||||
|
||||
# Initial progress update
|
||||
progress(0)
|
||||
|
||||
report = await self.report_generator.generate_report(
|
||||
search_results=search_results,
|
||||
query=query,
|
||||
|
@ -337,6 +370,9 @@ class GradioInterface:
|
|||
detail_level=detail_level
|
||||
)
|
||||
|
||||
# Final progress update
|
||||
progress(1.0)
|
||||
|
||||
# Process thinking tags if requested
|
||||
if process_thinking_tags:
|
||||
report = self._process_thinking_tags(report)
|
||||
|
|
Loading…
Reference in New Issue