Add query type selection to Gradio UI and improve report generation

This commit is contained in:
Steve White 2025-03-12 12:09:08 -05:00
parent c8c5240657
commit bf49474ca6
2 changed files with 38 additions and 6 deletions

View File

@ -207,7 +207,8 @@ class ReportGenerator:
token_budget: Optional[int] = None, token_budget: Optional[int] = None,
chunk_size: Optional[int] = None, chunk_size: Optional[int] = None,
overlap_size: Optional[int] = None, overlap_size: Optional[int] = None,
detail_level: Optional[str] = None) -> str: detail_level: Optional[str] = None,
query_type: Optional[str] = None) -> str:
""" """
Generate a report from search results. Generate a report from search results.
@ -234,6 +235,12 @@ class ReportGenerator:
overlap_size overlap_size
) )
# Log query type information
if query_type:
logger.info(f"Using specified query type: {query_type}")
else:
logger.info("Using automatic query type detection")
# Choose the appropriate synthesizer based on detail level # Choose the appropriate synthesizer based on detail level
if self.detail_level.lower() == "comprehensive": if self.detail_level.lower() == "comprehensive":
# Use progressive report synthesizer for comprehensive detail level # Use progressive report synthesizer for comprehensive detail level
@ -241,6 +248,7 @@ class ReportGenerator:
report = await self.progressive_report_synthesizer.synthesize_report( report = await self.progressive_report_synthesizer.synthesize_report(
selected_chunks, selected_chunks,
query, query,
query_type=query_type,
detail_level=self.detail_level detail_level=self.detail_level
) )
else: else:
@ -249,6 +257,7 @@ class ReportGenerator:
report = await self.report_synthesizer.synthesize_report( report = await self.report_synthesizer.synthesize_report(
selected_chunks, selected_chunks,
query, query,
query_type=query_type,
detail_level=self.detail_level detail_level=self.detail_level
) )

View File

@ -184,7 +184,7 @@ class GradioInterface:
return markdown return markdown
async def generate_report(self, query, detail_level="standard", custom_model=None, async def generate_report(self, query, detail_level="standard", query_type="auto-detect", custom_model=None,
results_file=None, process_thinking_tags=False, progress=gr.Progress()): results_file=None, process_thinking_tags=False, progress=gr.Progress()):
""" """
Generate a report for the given query. Generate a report for the given query.
@ -361,13 +361,22 @@ class GradioInterface:
# Initial progress update # Initial progress update
progress(0) progress(0)
# Handle query_type parameter
actual_query_type = None
if query_type != "auto-detect":
actual_query_type = query_type
print(f"Using user-selected query type: {actual_query_type}")
else:
print("Using auto-detection for query type")
report = await self.report_generator.generate_report( report = await self.report_generator.generate_report(
search_results=search_results, search_results=search_results,
query=query, query=query,
token_budget=config["token_budget"], token_budget=config["token_budget"],
chunk_size=config["chunk_size"], chunk_size=config["chunk_size"],
overlap_size=config["overlap_size"], overlap_size=config["overlap_size"],
detail_level=detail_level detail_level=detail_level,
query_type=actual_query_type
) )
# Final progress update # Final progress update
@ -538,6 +547,12 @@ class GradioInterface:
label="Detail Level", label="Detail Level",
info="Controls the depth and breadth of the report" info="Controls the depth and breadth of the report"
) )
report_query_type = gr.Dropdown(
choices=["auto-detect", "factual", "exploratory", "comparative"],
value="auto-detect",
label="Query Type",
info="Type of query determines the report structure"
)
model_descriptions = self.get_model_descriptions() model_descriptions = self.get_model_descriptions()
report_custom_model = gr.Dropdown( report_custom_model = gr.Dropdown(
choices=list(self.model_name_to_description.keys()), choices=list(self.model_name_to_description.keys()),
@ -573,12 +588,20 @@ class GradioInterface:
interactive=False interactive=False
) )
# Add information about detail levels # Add information about detail levels and query types
detail_levels_info = "" detail_levels_info = ""
for level, description in self.detail_level_manager.get_available_detail_levels(): for level, description in self.detail_level_manager.get_available_detail_levels():
detail_levels_info += f"- **{level}**: {description}\n" detail_levels_info += f"- **{level}**: {description}\n"
query_types_info = """
- **auto-detect**: Automatically determine the query type based on the query text
- **factual**: For queries seeking specific information (e.g., "What is...", "How does...")
- **exploratory**: For queries investigating a topic broadly (e.g., "Tell me about...")
- **comparative**: For queries comparing multiple items (e.g., "Compare X and Y", "Differences between...")
"""
gr.Markdown(f"### Detail Levels\n{detail_levels_info}") gr.Markdown(f"### Detail Levels\n{detail_levels_info}")
gr.Markdown(f"### Query Types\n{query_types_info}")
# Set up event handlers # Set up event handlers
search_button.click( search_button.click(
@ -588,8 +611,8 @@ class GradioInterface:
) )
report_button.click( report_button.click(
fn=lambda q, d, m, r, p: asyncio.run(self.generate_report(q, d, m, r, p)), fn=lambda q, d, t, m, r, p: asyncio.run(self.generate_report(q, d, t, m, r, p)),
inputs=[report_query_input, report_detail_level, report_custom_model, inputs=[report_query_input, report_detail_level, report_query_type, report_custom_model,
search_file_output, report_process_thinking], search_file_output, report_process_thinking],
outputs=[report_output, report_file_output] outputs=[report_output, report_file_output]
) )