# ira/execution/result_collector.py

"""
Result collector module.
Processes and organizes search results from multiple search engines.
"""
import json
from typing import Dict, List, Any, Optional
from urllib.parse import urlparse
from datetime import datetime
from ranking.jina_reranker import get_jina_reranker
class ResultCollector:
"""
Collects and processes search results from multiple search engines.
Handles deduplication, merging, and filtering of results.
"""
def __init__(self):
"""Initialize the result collector."""
        # Default to None so later checks (e.g. `self.reranker is not None`)
        # do not raise AttributeError when the reranker cannot be created.
        self.reranker = None
        try:
            self.reranker = get_jina_reranker()
            self.reranker_available = True
        except ValueError:
            print("Jina Reranker not available. Will use basic scoring instead.")
            self.reranker_available = False
        # Initialize result enrichers
        self.unpaywall_enricher = None
        try:
            from .result_enrichers.unpaywall_enricher import UnpaywallEnricher
            self.unpaywall_enricher = UnpaywallEnricher()
            self.unpaywall_available = True
        except (ImportError, ValueError):
            print("Unpaywall enricher not available. Will not enrich results with open access links.")
            self.unpaywall_available = False
def process_results(self,
search_results: Dict[str, List[Dict[str, Any]]],
dedup: bool = True,
max_results: Optional[int] = None,
use_reranker: bool = True) -> List[Dict[str, Any]]:
"""
Process search results from multiple search engines.
Args:
search_results: Dictionary mapping search engine names to lists of results
dedup: Whether to deduplicate results
max_results: Maximum number of results to return
use_reranker: Whether to use the reranker for semantic ranking
Returns:
List of processed search results
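        Example:
            An illustrative input shape (the keys shown are the ones this
            module reads; the engine names and values are placeholders):
                search_results = {
                    "serper": [{"title": "...", "snippet": "...",
                                "url": "https://example.com/page",
                                "query": "original query"}],
                    "arxiv": [{"title": "...", "snippet": "...",
                               "url": "https://arxiv.org/abs/0000.00000",
                               "query": "original query"}],
                }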
"""
# Flatten results from all search engines
flattened_results = []
for engine, results in search_results.items():
for result in results:
# Add the source to each result
result['source'] = engine
flattened_results.append(result)
        # Verify that the original query was propagated into the flattened results
        if flattened_results:
            first_result = flattened_results[0]
            query = first_result.get('query', '')
            print("Verifying query in flattened results:")
            print(f"Query in first result: {query[:50]}...")
        print(f"Processing {len(flattened_results)} combined results")
        # Deduplicate results if requested
        if dedup:
            flattened_results = self._deduplicate_results(flattened_results)
            print(f"Deduplicated to {len(flattened_results)} results")
# Enrich results with open access links if available
is_academic_query = any(result.get("source") in ["openalex", "core", "arxiv", "scholar"] for result in flattened_results)
if is_academic_query and hasattr(self, 'unpaywall_enricher') and self.unpaywall_available:
print("Enriching academic results with open access information")
try:
flattened_results = self.unpaywall_enricher.enrich_results(flattened_results)
print("Results enriched with open access information")
except Exception as e:
print(f"Error enriching results with Unpaywall: {str(e)}")
# Apply reranking if requested and available
if use_reranker and self.reranker is not None:
print("Using Jina Reranker for semantic ranking")
try:
reranked_results = self._rerank_results(flattened_results)
print(f"Reranked {len(reranked_results)} results")
processed_results = reranked_results
except Exception as e:
print(f"Error during reranking: {str(e)}. Falling back to basic scoring.")
print("Using basic scoring")
processed_results = self._score_and_sort_results(flattened_results)
else:
print("Using basic scoring")
processed_results = self._score_and_sort_results(flattened_results)
# Limit the number of results if requested
if max_results is not None:
processed_results = processed_results[:max_results]
print(f"Processed {len(processed_results)} results {'with' if use_reranker and self.reranker is not None else 'without'} reranking")
return processed_results
def _flatten_results(self, search_results: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""
Flatten search results from multiple search engines into a single list.
Args:
search_results: Dictionary mapping search engine names to lists of results
Returns:
Flattened list of search results
"""
# This method is deprecated and kept for backward compatibility
# The process_results method now handles flattened results directly
all_results = []
# Check if we have a flattened structure (single key with all results)
if len(search_results) == 1 and "combined" in search_results:
return search_results["combined"]
# Traditional structure with separate engines
for engine, results in search_results.items():
for result in results:
# Add the source if not already present
if "source" not in result:
result["source"] = engine
all_results.append(result)
return all_results
def _deduplicate_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Deduplicate results based on URL.
Args:
results: List of search results
Returns:
Deduplicated list of search results
"""
seen_urls = set()
deduplicated_results = []
for result in results:
url = result.get("url", "")
# Normalize URL for comparison
normalized_url = self._normalize_url(url)
if normalized_url and normalized_url not in seen_urls:
seen_urls.add(normalized_url)
deduplicated_results.append(result)
return deduplicated_results
def _score_and_sort_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Score and sort results by relevance.
Args:
results: List of search results
Returns:
Sorted list of search results
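        Example:
            A worked instance of the heuristic below: an "arxiv" result at raw
            position 2 with a 150-character snippet scores
            8 (source) + 8 (position, 10 - 2) + 2 (snippet length) = 18.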
"""
# Add a score to each result
for result in results:
score = 0
# Boost score based on source (e.g., scholarly sources get higher scores)
source = result.get("source", "")
if source == "scholar":
score += 10
elif source == "openalex":
score += 10 # Top priority for academic queries
elif source == "core":
score += 9 # High priority for open access academic content
elif source == "arxiv":
score += 8 # Good for preprints and specific fields
elif source == "github":
score += 9 # High priority for code/programming queries
elif source.startswith("stackexchange"):
score += 10 # Top priority for code/programming questions
elif source == "serper":
score += 7 # General web search
elif source == "news":
score += 8 # Good for current events
elif source == "google":
score += 5 # Generic search
# Boost score based on position in original results
position = result.get("raw_data", {}).get("position", 0)
if position > 0:
score += max(0, 10 - position)
# Boost score for results with more content
snippet_length = len(result.get("snippet", ""))
if snippet_length > 200:
score += 3
elif snippet_length > 100:
score += 2
elif snippet_length > 50:
score += 1
# Store the score
result["relevance_score"] = score
# Sort by score (descending)
sorted_results = sorted(results, key=lambda x: x.get("relevance_score", 0), reverse=True)
return sorted_results
def _rerank_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Rerank results using the Jina Reranker.
Args:
results: List of search results
Returns:
Reranked list of search results
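        Note:
            The reranker is expected to return a list of items shaped like
            {"index": <position in the snippets list>, "score": <float>};
            the loop below maps those indices back onto the original results.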
"""
if not results:
return []
# Extract the original query
# First try to get it from the first result
query = ""
for result in results:
if "query" in result:
query = result.get("query", "")
break
if not query:
# If no query is found, use a fallback approach
print("Warning: No query found in results. Using basic scoring instead.")
return self._score_and_sort_results(results)
print(f"Reranking with query: {query}")
# Extract snippets for reranking
snippets = []
for result in results:
# Combine title and snippet for better reranking
content = f"{result.get('title', '')} {result.get('snippet', '')}"
snippets.append(content)
try:
# Use the reranker to rerank the snippets
reranked = self.reranker.rerank(query, snippets)
if not reranked:
print("Reranker returned empty results. Using basic scoring instead.")
return self._score_and_sort_results(results)
print(f"Reranked {len(reranked)} results")
# Create a new list of results based on the reranking
reranked_results = []
for item in reranked:
# Get the original result and add the new score
index = item.get('index')
score = item.get('score')
if index is None or score is None or index >= len(results):
print(f"Warning: Invalid reranker result item: {item}")
continue
original_result = results[index]
new_result = original_result.copy()
new_result['relevance_score'] = float(score) * 10 # Scale up the score for consistency
reranked_results.append(new_result)
# If we didn't get any valid results, fall back to basic scoring
if not reranked_results:
print("No valid reranked results. Using basic scoring instead.")
return self._score_and_sort_results(results)
return reranked_results
except Exception as e:
print(f"Error reranking results: {str(e)}")
# Fall back to basic scoring if reranking fails
return self._score_and_sort_results(results)
def _extract_domain(self, url: str) -> str:
"""
Extract the domain from a URL.
Args:
url: URL to extract domain from
Returns:
Domain name
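        Example:
            "https://www.example.com/page" -> "example.com"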
"""
try:
parsed_url = urlparse(url)
domain = parsed_url.netloc
# Remove 'www.' prefix if present
if domain.startswith('www.'):
domain = domain[4:]
return domain
        except Exception:
return ""
def _normalize_url(self, url: str) -> str:
"""
Normalize a URL for comparison.
Args:
url: URL to normalize
Returns:
Normalized URL
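        Example:
            "HTTPS://Example.com/Path/?utm=1" -> "https://example.com/path"
            (query string dropped, trailing slash removed, lowercased)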
"""
try:
# Parse the URL
parsed_url = urlparse(url)
# Reconstruct with just the scheme, netloc, and path
normalized = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
# Remove trailing slash if present
if normalized.endswith('/'):
normalized = normalized[:-1]
return normalized.lower()
        except Exception:
return url.lower()
def filter_results(self,
results: List[Dict[str, Any]],
filters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Filter results based on specified criteria.
Args:
results: List of search results
filters: Dictionary of filter criteria:
- domains: List of domains to include or exclude
- exclude_domains: Whether to exclude (True) or include (False) the specified domains
- min_score: Minimum relevance score
- sources: List of sources to include
- date_range: Dictionary with 'start' and 'end' dates
Returns:
Filtered list of search results
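        Example:
            An illustrative filters dictionary using the keys described above:
                {
                    "domains": ["arxiv.org"],
                    "exclude_domains": False,
                    "min_score": 5,
                    "sources": ["arxiv", "openalex"],
                    "date_range": {"start": "2023-01-01", "end": "2023-12-31"},
                }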
"""
filtered_results = results.copy()
# Filter by domains
if "domains" in filters and filters["domains"]:
domains = set(filters["domains"])
exclude_domains = filters.get("exclude_domains", False)
if exclude_domains:
filtered_results = [r for r in filtered_results if r.get("domain", "") not in domains]
else:
filtered_results = [r for r in filtered_results if r.get("domain", "") in domains]
# Filter by minimum score
if "min_score" in filters:
min_score = filters["min_score"]
filtered_results = [r for r in filtered_results if r.get("relevance_score", 0) >= min_score]
# Filter by sources
if "sources" in filters and filters["sources"]:
sources = set(filters["sources"])
filtered_results = [r for r in filtered_results if r.get("source", "") in sources]
# Filter by date range
if "date_range" in filters:
date_range = filters["date_range"]
if "start" in date_range:
start_date = datetime.fromisoformat(date_range["start"])
filtered_results = [
r for r in filtered_results
if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) >= start_date
]
if "end" in date_range:
end_date = datetime.fromisoformat(date_range["end"])
filtered_results = [
r for r in filtered_results
if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) <= end_date
]
return filtered_results
def group_results_by_domain(self, results: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""
Group results by domain.
Args:
results: List of search results
Returns:
Dictionary mapping domains to lists of search results
"""
grouped_results = {}
for result in results:
domain = result.get("domain", "unknown")
if domain not in grouped_results:
grouped_results[domain] = []
grouped_results[domain].append(result)
return grouped_results
def save_results(self, results: List[Dict[str, Any]], file_path: str) -> None:
"""
Save search results to a file.
Args:
results: List of search results
file_path: Path to save results to
"""
try:
            with open(file_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2)
print(f"Results saved to {file_path}")
except Exception as e:
print(f"Error saving results: {e}")
def load_results(self, file_path: str) -> List[Dict[str, Any]]:
"""
Load search results from a file.
Args:
file_path: Path to load results from
Returns:
List of search results
"""
try:
            with open(file_path, 'r', encoding='utf-8') as f:
results = json.load(f)
return results
except Exception as e:
print(f"Error loading results: {e}")
return []
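

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The engine names and result keys
# below mirror those referenced throughout this module ('serper', 'arxiv',
# 'title', 'snippet', 'url', 'query'); the concrete values are made up, and
# the reranker / Unpaywall integrations may be unavailable locally, in which
# case the collector falls back to basic scoring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    collector = ResultCollector()
    sample_results = {
        "serper": [
            {
                "title": "Example web result",
                "snippet": "A short snippet describing the page.",
                "url": "https://www.example.com/page",
                "query": "example query",
            }
        ],
        "arxiv": [
            {
                "title": "Example preprint",
                "snippet": "Abstract text of the example preprint.",
                "url": "https://arxiv.org/abs/0000.00000",
                "query": "example query",
            }
        ],
    }
    processed = collector.process_results(sample_results, dedup=True, max_results=5)
    for item in processed:
        print(item.get("relevance_score"), item.get("source"), item.get("url"))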