diff --git a/report/report_generator.py b/report/report_generator.py index 2e8efad..36e2029 100644 --- a/report/report_generator.py +++ b/report/report_generator.py @@ -207,7 +207,8 @@ class ReportGenerator: token_budget: Optional[int] = None, chunk_size: Optional[int] = None, overlap_size: Optional[int] = None, - detail_level: Optional[str] = None) -> str: + detail_level: Optional[str] = None, + query_type: Optional[str] = None) -> str: """ Generate a report from search results. @@ -234,6 +235,12 @@ class ReportGenerator: overlap_size ) + # Log query type information + if query_type: + logger.info(f"Using specified query type: {query_type}") + else: + logger.info("Using automatic query type detection") + # Choose the appropriate synthesizer based on detail level if self.detail_level.lower() == "comprehensive": # Use progressive report synthesizer for comprehensive detail level @@ -241,6 +248,7 @@ class ReportGenerator: report = await self.progressive_report_synthesizer.synthesize_report( selected_chunks, query, + query_type=query_type, detail_level=self.detail_level ) else: @@ -249,6 +257,7 @@ class ReportGenerator: report = await self.report_synthesizer.synthesize_report( selected_chunks, query, + query_type=query_type, detail_level=self.detail_level ) diff --git a/ui/gradio_interface.py b/ui/gradio_interface.py index 3de40ef..78a1dbb 100644 --- a/ui/gradio_interface.py +++ b/ui/gradio_interface.py @@ -184,7 +184,7 @@ class GradioInterface: return markdown - async def generate_report(self, query, detail_level="standard", custom_model=None, + async def generate_report(self, query, detail_level="standard", query_type="auto-detect", custom_model=None, results_file=None, process_thinking_tags=False, progress=gr.Progress()): """ Generate a report for the given query. @@ -361,13 +361,22 @@ class GradioInterface: # Initial progress update progress(0) + # Handle query_type parameter + actual_query_type = None + if query_type != "auto-detect": + actual_query_type = query_type + print(f"Using user-selected query type: {actual_query_type}") + else: + print("Using auto-detection for query type") + report = await self.report_generator.generate_report( search_results=search_results, query=query, token_budget=config["token_budget"], chunk_size=config["chunk_size"], overlap_size=config["overlap_size"], - detail_level=detail_level + detail_level=detail_level, + query_type=actual_query_type ) # Final progress update @@ -538,6 +547,12 @@ class GradioInterface: label="Detail Level", info="Controls the depth and breadth of the report" ) + report_query_type = gr.Dropdown( + choices=["auto-detect", "factual", "exploratory", "comparative"], + value="auto-detect", + label="Query Type", + info="Type of query determines the report structure" + ) model_descriptions = self.get_model_descriptions() report_custom_model = gr.Dropdown( choices=list(self.model_name_to_description.keys()), @@ -573,12 +588,20 @@ class GradioInterface: interactive=False ) - # Add information about detail levels + # Add information about detail levels and query types detail_levels_info = "" for level, description in self.detail_level_manager.get_available_detail_levels(): detail_levels_info += f"- **{level}**: {description}\n" + query_types_info = """ + - **auto-detect**: Automatically determine the query type based on the query text + - **factual**: For queries seeking specific information (e.g., "What is...", "How does...") + - **exploratory**: For queries investigating a topic broadly (e.g., "Tell me about...") + - **comparative**: For queries comparing multiple items (e.g., "Compare X and Y", "Differences between...") + """ + gr.Markdown(f"### Detail Levels\n{detail_levels_info}") + gr.Markdown(f"### Query Types\n{query_types_info}") # Set up event handlers search_button.click( @@ -588,8 +611,8 @@ class GradioInterface: ) report_button.click( - fn=lambda q, d, m, r, p: asyncio.run(self.generate_report(q, d, m, r, p)), - inputs=[report_query_input, report_detail_level, report_custom_model, + fn=lambda q, d, t, m, r, p: asyncio.run(self.generate_report(q, d, t, m, r, p)), + inputs=[report_query_input, report_detail_level, report_query_type, report_custom_model, search_file_output, report_process_thinking], outputs=[report_output, report_file_output] )