ira/execution/result_collector.py

"""
Result collector module.
Processes and organizes search results from multiple search engines.
"""
import os
import json
import time
from typing import Dict, List, Any, Optional, Set
from urllib.parse import urlparse
from datetime import datetime

from ranking.jina_reranker import get_jina_reranker


class ResultCollector:
    """
    Collects and processes search results from multiple search engines.
    Handles deduplication, merging, and filtering of results.
    """

    def __init__(self):
        """Initialize the result collector."""
        try:
            self.reranker = get_jina_reranker()
            self.reranker_available = True
        except ValueError:
            print("Jina Reranker not available. Will use basic scoring instead.")
            self.reranker_available = False

    def process_results(self,
                        search_results: Dict[str, List[Dict[str, Any]]],
                        dedup: bool = True,
                        max_results: Optional[int] = None,
                        use_reranker: bool = True) -> List[Dict[str, Any]]:
        """
        Process search results from multiple search engines.

        Args:
            search_results: Dictionary mapping search engine names to lists of search results
            dedup: Whether to deduplicate results based on URL
            max_results: Maximum number of results to return (after processing)
            use_reranker: Whether to use the Jina Reranker for semantic ranking

        Returns:
            List of processed search results
        """
        # Flatten and normalize results
        all_results = self._flatten_results(search_results)

        # Deduplicate results if requested
        if dedup:
            all_results = self._deduplicate_results(all_results)

        # Use reranker if available and requested, otherwise use basic scoring
        if use_reranker and self.reranker_available:
            all_results = self._rerank_results(all_results)
        else:
            # Sort results by relevance (using a simple scoring algorithm)
            all_results = self._score_and_sort_results(all_results)

        # Limit results if requested
        if max_results is not None:
            all_results = all_results[:max_results]

        return all_results

    def _flatten_results(self, search_results: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """
        Flatten results from multiple search engines into a single list.

        Args:
            search_results: Dictionary mapping search engine names to lists of search results

        Returns:
            Flattened list of search results
        """
        all_results = []
        for engine, results in search_results.items():
            for result in results:
                # Ensure all results have the same basic structure
                normalized_result = {
                    "title": result.get("title", ""),
                    "url": result.get("url", ""),
                    "snippet": result.get("snippet", ""),
                    "source": result.get("source", engine),
                    "domain": self._extract_domain(result.get("url", "")),
                    "timestamp": datetime.now().isoformat(),
                    "raw_data": result
                }
                all_results.append(normalized_result)
        return all_results

    def _deduplicate_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Deduplicate results based on URL.

        Args:
            results: List of search results

        Returns:
            Deduplicated list of search results
        """
        seen_urls = set()
        deduplicated_results = []
        for result in results:
            url = result.get("url", "")
            # Normalize URL for comparison
            normalized_url = self._normalize_url(url)
            if normalized_url and normalized_url not in seen_urls:
                seen_urls.add(normalized_url)
                deduplicated_results.append(result)
        return deduplicated_results

    def _score_and_sort_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score and sort results by relevance.

        Args:
            results: List of search results

        Returns:
            Sorted list of search results
        """
        # Add a score to each result
        for result in results:
            score = 0

            # Boost score based on source (e.g., scholarly sources get higher scores)
            source = result.get("source", "")
            if source == "scholar":
                score += 10
            elif source == "serper":
                score += 9
            elif source == "arxiv":
                score += 8
            elif source == "google":
                score += 5

            # Boost score based on position in original results
            position = result.get("raw_data", {}).get("position", 0)
            if position > 0:
                score += max(0, 10 - position)

            # Boost score for results with more content
            snippet_length = len(result.get("snippet", ""))
            if snippet_length > 200:
                score += 3
            elif snippet_length > 100:
                score += 2
            elif snippet_length > 50:
                score += 1

            # Store the score
            result["relevance_score"] = score

        # Sort by score (descending)
        sorted_results = sorted(results, key=lambda x: x.get("relevance_score", 0), reverse=True)
        return sorted_results

    def _rerank_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank results using the Jina Reranker.

        Args:
            results: List of search results

        Returns:
            Reranked list of search results
        """
        if not results:
            return []

        # Get the original query from the first result (all should have the same query).
        # Note: _flatten_results does not copy a top-level "query" field onto the
        # normalized results, so also check the raw engine payload before giving up.
        query = results[0].get("query", "") or results[0].get("raw_data", {}).get("query", "")
        if not query:
            # If no query is found, use a fallback approach
            print("Warning: No query found in results. Using basic scoring instead.")
            return self._score_and_sort_results(results)

        # Extract snippets for reranking
        snippets = []
        for result in results:
            # Combine title and snippet for better reranking
            content = f"{result.get('title', '')} {result.get('snippet', '')}"
            snippets.append(content)

        try:
            # Use the reranker to rerank the snippets
            reranked = self.reranker.rerank(query, snippets)

            # Create a new list of results based on the reranking
            reranked_results = []
            for item in reranked:
                # Get the original result and add the new score
                original_result = results[item['index']]
                new_result = original_result.copy()
                new_result['relevance_score'] = item['score']
                reranked_results.append(new_result)
            return reranked_results
        except Exception as e:
            print(f"Error reranking results: {str(e)}")
            # Fall back to basic scoring if reranking fails
            return self._score_and_sort_results(results)

    def _extract_domain(self, url: str) -> str:
        """
        Extract the domain from a URL.

        Args:
            url: URL to extract domain from

        Returns:
            Domain name
        """
        try:
            parsed_url = urlparse(url)
            domain = parsed_url.netloc
            # Remove 'www.' prefix if present
            if domain.startswith('www.'):
                domain = domain[4:]
            return domain
        except Exception:
            return ""

    def _normalize_url(self, url: str) -> str:
        """
        Normalize a URL for comparison.

        Args:
            url: URL to normalize

        Returns:
            Normalized URL
        """
        try:
            # Parse the URL
            parsed_url = urlparse(url)
            # Reconstruct with just the scheme, netloc, and path
            normalized = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
            # Remove trailing slash if present
            if normalized.endswith('/'):
                normalized = normalized[:-1]
            return normalized.lower()
        except Exception:
            return url.lower()

    def filter_results(self,
                       results: List[Dict[str, Any]],
                       filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Filter results based on specified criteria.

        Args:
            results: List of search results
            filters: Dictionary of filter criteria:
                - domains: List of domains to include or exclude
                - exclude_domains: Whether to exclude (True) or include (False) the specified domains
                - min_score: Minimum relevance score
                - sources: List of sources to include
                - date_range: Dictionary with 'start' and 'end' ISO 8601 dates; results without a date are kept

        Returns:
            Filtered list of search results
        """
        filtered_results = results.copy()

        # Filter by domains
        if "domains" in filters and filters["domains"]:
            domains = set(filters["domains"])
            exclude_domains = filters.get("exclude_domains", False)
            if exclude_domains:
                filtered_results = [r for r in filtered_results if r.get("domain", "") not in domains]
            else:
                filtered_results = [r for r in filtered_results if r.get("domain", "") in domains]

        # Filter by minimum score
        if "min_score" in filters:
            min_score = filters["min_score"]
            filtered_results = [r for r in filtered_results if r.get("relevance_score", 0) >= min_score]

        # Filter by sources
        if "sources" in filters and filters["sources"]:
            sources = set(filters["sources"])
            filtered_results = [r for r in filtered_results if r.get("source", "") in sources]

        # Filter by date range
        if "date_range" in filters:
            date_range = filters["date_range"]
            if "start" in date_range:
                start_date = datetime.fromisoformat(date_range["start"])
                filtered_results = [
                    r for r in filtered_results
                    if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) >= start_date
                ]
            if "end" in date_range:
                end_date = datetime.fromisoformat(date_range["end"])
                filtered_results = [
                    r for r in filtered_results
                    if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) <= end_date
                ]

        return filtered_results

    def group_results_by_domain(self, results: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Group results by domain.

        Args:
            results: List of search results

        Returns:
            Dictionary mapping domains to lists of search results
        """
        grouped_results = {}
        for result in results:
            domain = result.get("domain", "unknown")
            if domain not in grouped_results:
                grouped_results[domain] = []
            grouped_results[domain].append(result)
        return grouped_results

    def save_results(self, results: List[Dict[str, Any]], file_path: str) -> None:
        """
        Save search results to a file.

        Args:
            results: List of search results
            file_path: Path to save results to
        """
        try:
            # Use UTF-8 explicitly so saved results are portable across platforms
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f"Results saved to {file_path}")
        except Exception as e:
            print(f"Error saving results: {e}")

    def load_results(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load search results from a file.

        Args:
            file_path: Path to load results from

        Returns:
            List of search results
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                results = json.load(f)
            return results
        except Exception as e:
            print(f"Error loading results: {e}")
            return []