Add query type selection to Gradio UI and improve report generation
This commit is contained in:
parent
c8c5240657
commit
bf49474ca6
|
@ -207,7 +207,8 @@ class ReportGenerator:
|
||||||
token_budget: Optional[int] = None,
|
token_budget: Optional[int] = None,
|
||||||
chunk_size: Optional[int] = None,
|
chunk_size: Optional[int] = None,
|
||||||
overlap_size: Optional[int] = None,
|
overlap_size: Optional[int] = None,
|
||||||
detail_level: Optional[str] = None) -> str:
|
detail_level: Optional[str] = None,
|
||||||
|
query_type: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a report from search results.
|
Generate a report from search results.
|
||||||
|
|
||||||
|
@ -234,6 +235,12 @@ class ReportGenerator:
|
||||||
overlap_size
|
overlap_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Log query type information
|
||||||
|
if query_type:
|
||||||
|
logger.info(f"Using specified query type: {query_type}")
|
||||||
|
else:
|
||||||
|
logger.info("Using automatic query type detection")
|
||||||
|
|
||||||
# Choose the appropriate synthesizer based on detail level
|
# Choose the appropriate synthesizer based on detail level
|
||||||
if self.detail_level.lower() == "comprehensive":
|
if self.detail_level.lower() == "comprehensive":
|
||||||
# Use progressive report synthesizer for comprehensive detail level
|
# Use progressive report synthesizer for comprehensive detail level
|
||||||
|
@ -241,6 +248,7 @@ class ReportGenerator:
|
||||||
report = await self.progressive_report_synthesizer.synthesize_report(
|
report = await self.progressive_report_synthesizer.synthesize_report(
|
||||||
selected_chunks,
|
selected_chunks,
|
||||||
query,
|
query,
|
||||||
|
query_type=query_type,
|
||||||
detail_level=self.detail_level
|
detail_level=self.detail_level
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -249,6 +257,7 @@ class ReportGenerator:
|
||||||
report = await self.report_synthesizer.synthesize_report(
|
report = await self.report_synthesizer.synthesize_report(
|
||||||
selected_chunks,
|
selected_chunks,
|
||||||
query,
|
query,
|
||||||
|
query_type=query_type,
|
||||||
detail_level=self.detail_level
|
detail_level=self.detail_level
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -184,7 +184,7 @@ class GradioInterface:
|
||||||
|
|
||||||
return markdown
|
return markdown
|
||||||
|
|
||||||
async def generate_report(self, query, detail_level="standard", custom_model=None,
|
async def generate_report(self, query, detail_level="standard", query_type="auto-detect", custom_model=None,
|
||||||
results_file=None, process_thinking_tags=False, progress=gr.Progress()):
|
results_file=None, process_thinking_tags=False, progress=gr.Progress()):
|
||||||
"""
|
"""
|
||||||
Generate a report for the given query.
|
Generate a report for the given query.
|
||||||
|
@ -361,13 +361,22 @@ class GradioInterface:
|
||||||
# Initial progress update
|
# Initial progress update
|
||||||
progress(0)
|
progress(0)
|
||||||
|
|
||||||
|
# Handle query_type parameter
|
||||||
|
actual_query_type = None
|
||||||
|
if query_type != "auto-detect":
|
||||||
|
actual_query_type = query_type
|
||||||
|
print(f"Using user-selected query type: {actual_query_type}")
|
||||||
|
else:
|
||||||
|
print("Using auto-detection for query type")
|
||||||
|
|
||||||
report = await self.report_generator.generate_report(
|
report = await self.report_generator.generate_report(
|
||||||
search_results=search_results,
|
search_results=search_results,
|
||||||
query=query,
|
query=query,
|
||||||
token_budget=config["token_budget"],
|
token_budget=config["token_budget"],
|
||||||
chunk_size=config["chunk_size"],
|
chunk_size=config["chunk_size"],
|
||||||
overlap_size=config["overlap_size"],
|
overlap_size=config["overlap_size"],
|
||||||
detail_level=detail_level
|
detail_level=detail_level,
|
||||||
|
query_type=actual_query_type
|
||||||
)
|
)
|
||||||
|
|
||||||
# Final progress update
|
# Final progress update
|
||||||
|
@ -538,6 +547,12 @@ class GradioInterface:
|
||||||
label="Detail Level",
|
label="Detail Level",
|
||||||
info="Controls the depth and breadth of the report"
|
info="Controls the depth and breadth of the report"
|
||||||
)
|
)
|
||||||
|
report_query_type = gr.Dropdown(
|
||||||
|
choices=["auto-detect", "factual", "exploratory", "comparative"],
|
||||||
|
value="auto-detect",
|
||||||
|
label="Query Type",
|
||||||
|
info="Type of query determines the report structure"
|
||||||
|
)
|
||||||
model_descriptions = self.get_model_descriptions()
|
model_descriptions = self.get_model_descriptions()
|
||||||
report_custom_model = gr.Dropdown(
|
report_custom_model = gr.Dropdown(
|
||||||
choices=list(self.model_name_to_description.keys()),
|
choices=list(self.model_name_to_description.keys()),
|
||||||
|
@ -573,12 +588,20 @@ class GradioInterface:
|
||||||
interactive=False
|
interactive=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add information about detail levels
|
# Add information about detail levels and query types
|
||||||
detail_levels_info = ""
|
detail_levels_info = ""
|
||||||
for level, description in self.detail_level_manager.get_available_detail_levels():
|
for level, description in self.detail_level_manager.get_available_detail_levels():
|
||||||
detail_levels_info += f"- **{level}**: {description}\n"
|
detail_levels_info += f"- **{level}**: {description}\n"
|
||||||
|
|
||||||
|
query_types_info = """
|
||||||
|
- **auto-detect**: Automatically determine the query type based on the query text
|
||||||
|
- **factual**: For queries seeking specific information (e.g., "What is...", "How does...")
|
||||||
|
- **exploratory**: For queries investigating a topic broadly (e.g., "Tell me about...")
|
||||||
|
- **comparative**: For queries comparing multiple items (e.g., "Compare X and Y", "Differences between...")
|
||||||
|
"""
|
||||||
|
|
||||||
gr.Markdown(f"### Detail Levels\n{detail_levels_info}")
|
gr.Markdown(f"### Detail Levels\n{detail_levels_info}")
|
||||||
|
gr.Markdown(f"### Query Types\n{query_types_info}")
|
||||||
|
|
||||||
# Set up event handlers
|
# Set up event handlers
|
||||||
search_button.click(
|
search_button.click(
|
||||||
|
@ -588,8 +611,8 @@ class GradioInterface:
|
||||||
)
|
)
|
||||||
|
|
||||||
report_button.click(
|
report_button.click(
|
||||||
fn=lambda q, d, m, r, p: asyncio.run(self.generate_report(q, d, m, r, p)),
|
fn=lambda q, d, t, m, r, p: asyncio.run(self.generate_report(q, d, t, m, r, p)),
|
||||||
inputs=[report_query_input, report_detail_level, report_custom_model,
|
inputs=[report_query_input, report_detail_level, report_query_type, report_custom_model,
|
||||||
search_file_output, report_process_thinking],
|
search_file_output, report_process_thinking],
|
||||||
outputs=[report_output, report_file_output]
|
outputs=[report_output, report_file_output]
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue