# ira/ui/gradio_interface.py

"""
Gradio interface for the intelligent research system.
This module provides a web interface for users to interact with the research system.
"""
import asyncio
import inspect
import json
import os
import re
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path

import gradio as gr

# Add the parent directory to the path to allow importing from other modules
sys.path.append(str(Path(__file__).parent.parent))

from query.query_processor import QueryProcessor
from execution.search_executor import SearchExecutor
from execution.result_collector import ResultCollector
from report.report_generator import get_report_generator, initialize_report_generator
from report.report_detail_levels import get_report_detail_level_manager, DetailLevel
from config.config import Config
class GradioInterface:
    """Gradio interface for the intelligent research system.

    Wires the query processor, search executor, result collector and report
    generator into a two-tab web UI ("Search" and "Generate Report").
    """

    # Matches model reasoning blocks in both the ``<thinking>...</thinking>``
    # form and the shorter ``<think>...</think>`` form (emitted by e.g.
    # DeepSeek-R1 style models); DOTALL lets a block span several lines.
    _THINKING_TAG_RE = re.compile(r"<(thinking|think)>.*?</\1>", flags=re.DOTALL)

    def __init__(self):
        """Initialize the Gradio interface and its backend components."""
        self.query_processor = QueryProcessor()
        self.search_executor = SearchExecutor()
        self.result_collector = ResultCollector()
        self.results_dir = Path(__file__).parent.parent / "results"
        self.results_dir.mkdir(exist_ok=True)
        # NOTE(review): reports are written to the project root itself rather
        # than a sub-directory like results_dir - confirm this is intended.
        self.reports_dir = Path(__file__).parent.parent
        self.reports_dir.mkdir(exist_ok=True)
        self.detail_level_manager = get_report_detail_level_manager()
        self.config = Config()
        # The report generator will be initialized in the async init method
        self.report_generator = None
        # Display label -> detailed description; rebuilt by get_model_descriptions()
        self.model_name_to_description = {}
        # Progress tracking elements (will be set in create_interface)
        self.report_progress = None
        self.report_progress_bar = None

    async def async_init(self):
        """Asynchronously initialize components that require async initialization.

        Returns:
            GradioInterface: ``self``, so the call can be chained.
        """
        await initialize_report_generator()
        self.report_generator = get_report_generator()
        return self

    @staticmethod
    async def _maybe_await(value):
        """Await *value* if it is awaitable, otherwise return it unchanged."""
        if inspect.isawaitable(value):
            return await value
        return value

    def process_query(self, query, num_results=10, use_reranker=True):
        """
        Process a query and return the results.

        Args:
            query (str): The query to process
            num_results (int): Number of results to request from each search engine
            use_reranker (bool): Whether to use the Jina Reranker for semantic ranking

        Returns:
            tuple: (markdown_results, json_results_path). ``json_results_path``
            is None when nothing was found or an error occurred.
        """
        try:
            print(f"Processing query: {query}")
            processed_query = self.query_processor.process_query(query)
            if inspect.isawaitable(processed_query):
                # generate_report() awaits this same call, so the processor may
                # be async. This method is synchronous; it assumes no event loop
                # is running in the calling thread (otherwise asyncio.run raises
                # and the error is reported below).
                processed_query = asyncio.run(self._maybe_await(processed_query))
            print(f"Processed query: {processed_query}")

            available_engines = self.search_executor.get_available_search_engines()
            print(f"Available search engines: {available_engines}")

            # Report which handlers are actually usable
            for engine_name, handler in self.search_executor.available_handlers.items():
                is_available = handler.is_available()
                print(f"Handler {engine_name} available: {is_available}")
                if not is_available:
                    print(f" - Reason: API key may be missing for {engine_name}")

            # Default to every available engine if the processor did not choose
            if 'search_engines' not in processed_query:
                processed_query['search_engines'] = available_engines
                print(f"Using search engines: {available_engines}")

            print("Executing search...")
            search_results = self.search_executor.execute_search(
                structured_query=processed_query,
                num_results=num_results
            )
            for engine, results in search_results.items():
                print(f"Engine {engine} returned {len(results)} results")

            # Tag every result with the query (used for reranking) and its engine,
            # then flatten into a single list.
            enhanced_query = processed_query.get("enhanced_query", processed_query.get("original_query", query))
            flattened_results = []
            for engine, results in search_results.items():
                for result in results:
                    result["query"] = enhanced_query
                    result["engine"] = engine
                    flattened_results.append(result)

            # Deduplicate / rerank without capping the number of results
            print("Processing results...")
            processed_results = self.result_collector.process_results(
                {"combined": flattened_results}, dedup=True, max_results=None, use_reranker=use_reranker
            )
            print(f"Processed {len(processed_results)} results")

            if not processed_results:
                error_message = "No results found. Please try a different query or check API keys."
                print(error_message)
                return f"## No Results Found\n\n{error_message}", None

            timestamp = int(time.time())
            results_file = self.results_dir / f"results_{timestamp}.json"
            with open(results_file, "w", encoding="utf-8") as f:
                json.dump(processed_results, f, indent=2)
            print(f"Results saved to {results_file}")

            return self._format_results_as_markdown(processed_results), str(results_file)
        except Exception as e:
            # UI boundary: report the failure in the page instead of raising
            error_message = f"Error processing query: {str(e)}"
            print(f"ERROR: {error_message}")
            traceback.print_exc()
            return f"## Error\n\n{error_message}", None

    def _format_results_as_markdown(self, results):
        """
        Format results as markdown.

        Args:
            results (list): List of result dictionaries

        Returns:
            str: A per-source count summary followed by one section per result
            (title, source, URL, snippet, authors, year, relevance score)
        """
        if not results:
            return "## No Results Found\n\nNo results were found for your query."

        # Count results by source, keeping first-seen order
        source_counts = {}
        for result in results:
            source = result.get("source", "unknown")
            source_counts[source] = source_counts.get(source, 0) + 1
        source_distribution = ", ".join(f"{source}: {count}" for source, count in source_counts.items())

        parts = ["## Search Results\n\n", f"*Sources: {source_distribution}*\n\n"]
        for i, result in enumerate(results, start=1):
            url = result.get("url", "")
            parts.extend([
                f"### {i}. {result.get('title', 'Untitled')}\n\n",
                f"**Source**: {result.get('source', 'unknown')}\n\n",
                f"**URL**: [{url}]({url})\n\n",
                f"**Snippet**: {result.get('snippet', 'No snippet available')}\n\n",
                f"**Authors**: {result.get('authors', 'Unknown')}\n\n",
                f"**Year**: {result.get('year', 'Unknown')}\n\n",
                f"**Score**: {result.get('relevance_score', 0)}\n\n",
                "---\n\n",
            ])
        return "".join(parts)

    @staticmethod
    def _flatten_search_results(search_results_dict):
        """Concatenate the per-engine result lists into a single list."""
        return [result for engine_results in search_results_dict.values() for result in engine_results]

    @staticmethod
    def _mock_search_result(query):
        """Build a placeholder result so report generation can proceed when every search fails."""
        return {
            "title": f"Information about: {query}",
            "url": "https://example.com/search-result",
            "snippet": f"This is a placeholder result for the query: {query}. " +
                       "The search system was unable to find relevant results. " +
                       "Please try refining your query or check your search API configuration.",
            "source": "mock_result",
            "score": 1.0
        }

    async def _search_for_report(self, query, config):
        """
        Run a fresh search for *query* to feed report generation.

        Falls back to the query's first 10 words when nothing is found, and
        finally to a single placeholder result, so the returned list is never empty.

        Args:
            query (str): The user's research query
            config (dict): Detail-level configuration (may contain
                ``initial_results_per_engine`` and/or ``num_results``)

        Returns:
            list: Flattened search results
        """
        # The processor may be sync or async; accept either
        structured_query = await self._maybe_await(self.query_processor.process_query(query))
        structured_query = await self._maybe_await(
            self.query_processor.generate_search_queries(
                structured_query,
                self.search_executor.get_available_search_engines()
            )
        )

        # Use initial_results_per_engine if available, otherwise fall back to num_results
        num_results_to_fetch = config.get("initial_results_per_engine", config.get("num_results", 10))
        search_results_dict = self.search_executor.execute_search(
            structured_query,
            num_results=num_results_to_fetch
        )
        print("Search results by engine:")
        for engine, results in search_results_dict.items():
            print(f" {engine}: {len(results)} results")

        search_results = self._flatten_search_results(search_results_dict)
        print(f"Total flattened search results: {len(search_results)}")
        if search_results:
            return search_results

        print("WARNING: No search results found. Using fallback search mechanism...")
        simplified_query = " ".join(query.split(" ")[:10])  # Take first 10 words
        if simplified_query != query:
            print(f"Trying simplified query: {simplified_query}")
            basic_structured_query = {
                "original_query": simplified_query,
                "enhanced_query": simplified_query,
                "type": "unknown",
                "intent": "research"
            }
            # "num_results" is optional in the config (see the primary search
            # above), so do not index it directly.
            search_results_dict = self.search_executor.execute_search(
                basic_structured_query,
                num_results=config.get("num_results", num_results_to_fetch)
            )
            search_results = self._flatten_search_results(search_results_dict)
            print(f"Fallback search returned {len(search_results)} results")

        if not search_results:
            print("WARNING: Fallback search also failed. Creating mock search result...")
            search_results = [self._mock_search_result(query)]
            print("Created mock search result to allow report generation to proceed")
        return search_results

    def _progress_status_message(self, current_progress, total_chunks):
        """Build the status line shown for a report-generation progress update."""
        current_chunk = int(current_progress * total_chunks) if total_chunks > 0 else 0
        if current_progress == 0:
            status_message = "Preparing documents..."
        elif current_progress >= 1.0:
            status_message = "Finalizing report..."
        else:
            status_message = f"Processing chunk {current_chunk}/{total_chunks}..."
        # Add current chunk title if the generator exposes one
        chunk_title = getattr(self.report_generator, 'current_chunk_title', None)
        if chunk_title:
            status_message += f" ({chunk_title})"
        return status_message

    def _set_progress_widgets(self, status_message, percent):
        """Store progress on the status textbox / bar, if create_interface() has built them.

        NOTE(review): this assigns the components' ``.value``; it does not emit a
        Gradio update from an event handler, so connected pages may not refresh
        from it - confirm against the Gradio version in use.
        """
        if self.report_progress is not None:
            self.report_progress.value = status_message
        if self.report_progress_bar is not None:
            self.report_progress_bar.value = percent

    async def generate_report(self, query, detail_level="standard", query_type="auto-detect", custom_model=None,
                              results_file=None, process_thinking_tags=False, initial_results=10, final_results=7,
                              progress=gr.Progress()):
        """
        Generate a report for the given query.

        Args:
            query: The query to generate a report for
            detail_level: The level of detail for the report (brief, standard, detailed, comprehensive)
            query_type: "auto-detect" to let the generator decide, otherwise an
                explicit type passed through to the report generator
            custom_model: Custom model to use for report generation; a dropdown
                label such as "name (provider: model)" is reduced to "name"
            results_file: Path to a JSON file of search results; when falsy or
                not an existing path, a fresh search is performed instead
            process_thinking_tags: Whether to strip thinking tags from the model output
            initial_results: Results to fetch per engine (overrides the config when truthy)
            final_results: Results to keep after reranking (overrides the config when truthy)
            progress: Gradio progress indicator

        Returns:
            tuple: (report_markdown, report_path), or (error_markdown, None) on failure
        """
        try:
            # Create a timestamped output file, suffixed with the model name if any
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_suffix = ""
            if custom_model:
                # Dropdown labels look like "model_name (provider: model_display)"
                if "(" in custom_model:
                    custom_model = custom_model.split(" (")[0]
                model_suffix = f"_{custom_model.split('/')[-1]}"
            output_file = self.reports_dir / f"report_{timestamp}{model_suffix}.md"

            # Copy so the per-request overrides below cannot leak into the
            # detail level manager's configuration if it returns a shared dict.
            config = dict(self.detail_level_manager.get_detail_level_config(detail_level))
            if initial_results:
                config["initial_results_per_engine"] = initial_results
            if final_results:
                config["final_results_after_reranking"] = final_results
            if custom_model:
                # NOTE(review): config["model"] is only logged below; nothing in
                # this method hands the custom model to the report generator.
                config["model"] = custom_model

            # Ensure report generator is initialized
            if self.report_generator is None:
                print("Initializing report generator...")
                await initialize_report_generator()
                self.report_generator = get_report_generator()
            self.report_generator.set_detail_level(detail_level)

            print(f"Generating report with detail level: {detail_level}")
            print(f"Detail level configuration: {config}")
            print(f"Using model: {config.get('model')}")
            print(f"Processing thinking tags: {process_thinking_tags}")

            if results_file and os.path.exists(results_file):
                with open(results_file, 'r', encoding='utf-8') as f:
                    search_results = json.load(f)
                print(f"Loaded {len(search_results)} results from {results_file}")
            else:
                print(f"No results file provided, performing search for: {query}")
                search_results = await self._search_for_report(query, config)

            # Rerank only if a reranker has been attached to this instance;
            # nothing in this class assigns one.
            reranker = getattr(self, 'reranker', None)
            if reranker:
                # Use final_results_after_reranking if available, otherwise fall back to num_results
                top_n_results = config.get("final_results_after_reranking", config.get("num_results", 7))
                search_results = reranker.rerank_with_metadata(
                    query,
                    search_results,
                    document_key='snippet',
                    top_n=top_n_results
                )

            def ui_progress_callback(current_progress, total_chunks, current_report):
                status_message = self._progress_status_message(current_progress, total_chunks)
                progress(current_progress, desc=status_message)
                self._set_progress_widgets(status_message, int(current_progress * 100))
                return status_message

            if hasattr(self.report_generator, 'set_progress_callback'):
                self.report_generator.set_progress_callback(ui_progress_callback)

            print(f"Generating report with {len(search_results)} search results")
            if len(search_results) == 0:
                print("WARNING: No search results found. Report generation may fail.")

            # Status text by detail level (not read elsewhere in this class)
            if detail_level.lower() == "comprehensive":
                self.progress_status = "Generating progressive report..."
            else:
                self.progress_status = "Processing document chunks..."
            self._set_progress_widgets("Preparing documents...", 0)

            # None tells the generator to auto-detect the query type
            if query_type != "auto-detect":
                actual_query_type = query_type
                print(f"Using user-selected query type: {actual_query_type}")
            else:
                actual_query_type = None
                print("Using auto-detection for query type")

            report = await self.report_generator.generate_report(
                search_results=search_results,
                query=query,
                token_budget=config["token_budget"],
                chunk_size=config["chunk_size"],
                overlap_size=config["overlap_size"],
                detail_level=detail_level,
                query_type=actual_query_type
            )
            progress(1.0)

            if process_thinking_tags:
                report = self._process_thinking_tags(report)

            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(report)
            print(f"Report saved to: {output_file}")
            return report, str(output_file)
        except Exception as e:
            # UI boundary: report the failure in the page instead of raising
            error_message = f"Error generating report: {str(e)}"
            print(f"ERROR: {error_message}")
            traceback.print_exc()
            return f"## Error\n\n{error_message}", None

    def _process_thinking_tags(self, text):
        """
        Remove model reasoning blocks from the text.

        Strips ``<thinking>...</thinking>`` and ``<think>...</think>`` spans,
        including their contents (which may span multiple lines).

        Args:
            text (str): Text to process

        Returns:
            str: Processed text
        """
        return self._THINKING_TAG_RE.sub('', text)

    def get_available_models(self):
        """
        Get a list of available models for report generation.

        Returns:
            list: Model keys from the config's ``models`` section, or a built-in
            default list when the config defines none
        """
        models = []
        if 'models' in self.config.config_data:
            models = list(self.config.config_data['models'].keys())
        if not models:
            models = [
                "llama-3.1-8b-instant",
                "llama-3.3-70b-versatile",
                "groq/deepseek-r1-distill-llama-70b-specdec",
                "openrouter-mixtral",
                "openrouter-claude",
                "gemini-2.0-flash-lite"
            ]
        return models

    def get_model_descriptions(self):
        """
        Build display labels for the models defined in the config.

        Also rebuilds ``self.model_name_to_description``, which maps each display
        label ("name (provider: model)") to a longer description with max tokens
        and temperature; create_interface() uses its keys as dropdown choices.

        Returns:
            dict: Mapping of config model key -> display label
        """
        descriptions = {}
        model_name_to_description = {}
        if 'models' in self.config.config_data:
            for model_name, model_config in self.config.config_data['models'].items():
                provider = model_config.get('provider', 'unknown')
                model_display = model_config.get('model_name', model_name)
                max_tokens = model_config.get('max_tokens', 'unknown')
                temperature = model_config.get('temperature', 'unknown')
                display_name = f"{model_name} ({provider}: {model_display})"
                descriptions[model_name] = display_name
                model_name_to_description[display_name] = (
                    f"{display_name} - Max tokens: {max_tokens}, Temperature: {temperature}"
                )
        self.model_name_to_description = model_name_to_description
        return descriptions

    def create_interface(self):
        """
        Create and return the Gradio interface.

        Builds a "Search" tab wired to process_query() and a "Generate Report"
        tab wired to generate_report(), and stores the report progress widgets
        on ``self.report_progress`` / ``self.report_progress_bar``.

        Returns:
            gr.Blocks: The Gradio interface
        """
        # Blank line before "**Special Capabilities:**" so Markdown renders it
        # as its own paragraph instead of merging it into the sentence above.
        intro_markdown = (
            "\n"
            "This system helps you research topics by searching across multiple sources\n"
            "including Google (via Serper), Google Scholar, arXiv, and news sources.\n"
            "You can either search for results or generate a comprehensive report.\n"
            "\n"
            "**Special Capabilities:**\n"
            "- Automatically detects and optimizes current events queries\n"
            "- Specialized search handlers for different types of information\n"
            "- Semantic ranking for the most relevant results\n"
        )

        with gr.Blocks(title="Intelligent Research System") as interface:
            gr.Markdown("# Intelligent Research System")
            gr.Markdown(intro_markdown)

            with gr.Tabs():
                with gr.TabItem("Search"):
                    with gr.Row():
                        with gr.Column(scale=4):
                            search_query_input = gr.Textbox(
                                label="Research Query",
                                placeholder="Enter your research question here...",
                                lines=3
                            )
                        with gr.Column(scale=1):
                            search_num_results = gr.Slider(
                                minimum=5,
                                maximum=50,
                                value=20,
                                step=5,
                                label="Results Per Engine"
                            )
                            search_use_reranker = gr.Checkbox(
                                label="Use Semantic Reranker",
                                value=True,
                                info="Uses Jina AI's reranker for more relevant results"
                            )
                    search_button = gr.Button("Search", variant="primary")
                    gr.Examples(
                        examples=[
                            ["What are the latest advancements in quantum computing?"],
                            ["Compare transformer and RNN architectures for NLP tasks"],
                            ["Explain the environmental impact of electric vehicles"],
                            ["What recent actions has Trump taken regarding tariffs?"],
                            ["What are the recent papers on large language model alignment?"],
                            ["What are the main research findings on climate change adaptation strategies in agriculture?"]
                        ],
                        inputs=search_query_input
                    )
                    with gr.Row():
                        with gr.Column():
                            search_results_output = gr.Markdown(label="Results")
                    with gr.Row():
                        with gr.Column():
                            search_file_output = gr.Textbox(
                                label="Results saved to file",
                                interactive=False
                            )

                with gr.TabItem("Generate Report"):
                    with gr.Row():
                        with gr.Column(scale=4):
                            report_query_input = gr.Textbox(
                                label="Research Query",
                                placeholder="Enter your research question here...",
                                lines=3
                            )
                        with gr.Column(scale=1):
                            report_detail_level = gr.Dropdown(
                                choices=["brief", "standard", "detailed", "comprehensive"],
                                value="standard",
                                label="Detail Level",
                                info="Controls the depth and breadth of the report"
                            )
                            report_query_type = gr.Dropdown(
                                choices=["auto-detect", "factual", "exploratory", "comparative", "code"],
                                value="auto-detect",
                                label="Query Type",
                                info="Type of query determines the report structure"
                            )
                            # Called for its side effect: fills self.model_name_to_description
                            self.get_model_descriptions()
                            report_custom_model = gr.Dropdown(
                                choices=list(self.model_name_to_description.keys()),
                                value=None,
                                label="Custom Model (Optional)",
                                info="Select a custom model for report generation"
                            )
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### Advanced Settings")
                    with gr.Row():
                        with gr.Column():
                            with gr.Accordion("Search Parameters", open=False):
                                with gr.Row():
                                    initial_results_slider = gr.Slider(
                                        minimum=5,
                                        maximum=50,
                                        value=10,
                                        step=5,
                                        label="Initial Results Per Engine",
                                        info="Number of results to fetch from each search engine"
                                    )
                                    final_results_slider = gr.Slider(
                                        minimum=3,
                                        maximum=30,
                                        value=7,
                                        step=1,
                                        label="Final Results After Reranking",
                                        info="Number of results to keep after reranking"
                                    )
                            with gr.Accordion("Processing Options", open=False):
                                with gr.Row():
                                    report_process_thinking = gr.Checkbox(
                                        label="Process Thinking Tags",
                                        value=False,
                                        info="Process <thinking> tags in model output"
                                    )
                    with gr.Row():
                        report_button = gr.Button("Generate Report", variant="primary", size="lg")
                    with gr.Row():
                        with gr.Column():
                            # Progress indicator that will be updated by the progress callback
                            self.report_progress = gr.Textbox(
                                label="Progress Status",
                                value="Ready",
                                interactive=False
                            )
                    with gr.Row():
                        with gr.Column():
                            # Progress bar to show visual progress
                            self.report_progress_bar = gr.Slider(
                                minimum=0,
                                maximum=100,
                                value=0,
                                step=1,
                                label="Progress",
                                interactive=False
                            )
                    gr.Examples(
                        examples=[
                            ["What are the latest advancements in quantum computing?"],
                            ["Compare transformer and RNN architectures for NLP tasks"],
                            ["Explain the environmental impact of electric vehicles"],
                            ["Explain the potential relationship between creatine supplementation and muscle loss due to GLP1-ar drugs for weight loss."],
                            ["What recent actions has Trump taken regarding tariffs?"],
                            ["What are the recent papers on large language model alignment?"],
                            ["What are the main research findings on climate change adaptation strategies in agriculture?"]
                        ],
                        inputs=report_query_input
                    )
                    with gr.Row():
                        with gr.Column():
                            report_output = gr.Markdown(label="Generated Report")
                    with gr.Row():
                        with gr.Column():
                            report_file_output = gr.Textbox(
                                label="Report saved to file",
                                interactive=False
                            )

                    # Add information about detail levels and query types
                    detail_levels_info = "".join(
                        f"- **{level}**: {description}\n"
                        for level, description in self.detail_level_manager.get_available_detail_levels()
                    )
                    query_types_info = (
                        "\n"
                        "- **auto-detect**: Automatically determine the query type based on the query text\n"
                        "- **factual**: For queries seeking specific information (e.g., \"What is...\", \"How does...\")\n"
                        "- **exploratory**: For queries investigating a topic broadly (e.g., \"Tell me about...\")\n"
                        "- **comparative**: For queries comparing multiple items (e.g., \"Compare X and Y\", \"Differences between...\")\n"
                        "- **code**: For queries related to programming, software development, or technical implementation\n"
                    )
                    gr.Markdown(f"### Detail Levels\n{detail_levels_info}")
                    gr.Markdown(f"### Query Types\n{query_types_info}")

            # Set up event handlers
            search_button.click(
                fn=self.process_query,
                inputs=[search_query_input, search_num_results, search_use_reranker],
                outputs=[search_results_output, search_file_output]
            )

            async def generate_report_with_progress(query, detail_level, query_type, model_name, results_file,
                                                    process_thinking_tags, initial_results, final_results):
                # Keyword arguments keep the UI inputs bound to the right parameters
                return await self.generate_report(
                    query,
                    detail_level=detail_level,
                    query_type=query_type,
                    custom_model=model_name,
                    results_file=results_file,
                    process_thinking_tags=process_thinking_tags,
                    initial_results=initial_results,
                    final_results=final_results,
                )

            report_button.click(
                # Sync entry point; asyncio.run gives each click its own event loop
                fn=lambda q, d, t, m, r, p, i, f: asyncio.run(generate_report_with_progress(q, d, t, m, r, p, i, f)),
                inputs=[report_query_input, report_detail_level, report_query_type, report_custom_model,
                        search_file_output, report_process_thinking, initial_results_slider, final_results_slider],
                outputs=[report_output, report_file_output]
            )

        return interface

    def launch(self, **kwargs):
        """
        Build the Gradio interface and launch it.

        Args:
            **kwargs: Keyword arguments forwarded to the ``launch()`` method of
                the ``gr.Blocks`` returned by create_interface()
        """
        interface = self.create_interface()
        interface.launch(**kwargs)
def main():
    """Create the Gradio interface, run its async initialization, then launch it."""
    interface = GradioInterface()
    # Create and install the loop explicitly: asyncio.get_event_loop() is
    # deprecated when no event loop is set and raises from Python 3.14 on.
    # The loop is intentionally left open (as before) so anything bound to it
    # during async_init() is not torn down before launch.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(interface.async_init())
    # Launch the interface with a public share link
    interface.launch(share=True)


if __name__ == "__main__":
    main()