"""
|
|
Result collector module.
|
|
Processes and organizes search results from multiple search engines.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import time
|
|
from typing import Dict, List, Any, Optional, Set
|
|
from urllib.parse import urlparse
|
|
from datetime import datetime
|
|
|
|
from ranking.jina_reranker import get_jina_reranker
|
|
|
|
|
|
class ResultCollector:
    """
    Collects and processes search results from multiple search engines.
    Handles deduplication, merging, and filtering of results.
    """

    def __init__(self):
        """Initialize the result collector."""
        try:
            self.reranker = get_jina_reranker()
            self.reranker_available = True
        except ValueError:
            print("Jina Reranker not available. Will use basic scoring instead.")
            # Keep the attribute defined so later `self.reranker is not None` checks are safe.
            self.reranker = None
            self.reranker_available = False

        # Initialize result enrichers
        try:
            from .result_enrichers.unpaywall_enricher import UnpaywallEnricher
            self.unpaywall_enricher = UnpaywallEnricher()
            self.unpaywall_available = True
        except (ImportError, ValueError):
            print("Unpaywall enricher not available. Will not enrich results with open access links.")
            self.unpaywall_enricher = None
            self.unpaywall_available = False

    def process_results(self,
                        search_results: Dict[str, List[Dict[str, Any]]],
                        dedup: bool = True,
                        max_results: Optional[int] = None,
                        use_reranker: bool = True) -> List[Dict[str, Any]]:
        """
        Process search results from multiple search engines.

        Args:
            search_results: Dictionary mapping search engine names to lists of results
            dedup: Whether to deduplicate results
            max_results: Maximum number of results to return
            use_reranker: Whether to use the reranker for semantic ranking

        Returns:
            List of processed search results
        """
        # Flatten results from all search engines
        flattened_results = []
        for engine, results in search_results.items():
            for result in results:
                # Add the source to each result
                result['source'] = engine
                flattened_results.append(result)

        # Print a verification of the query in the flattened results
        if flattened_results:
            first_result = flattened_results[0]
            query = first_result.get('query', '')
            print("Verifying query in flattened results:")
            print(f"Query in first result: {query[:50]}...")

        print(f"Processing {len(flattened_results)} combined results")

        # Deduplicate results if requested
        if dedup:
            flattened_results = self._deduplicate_results(flattened_results)
            print(f"Deduplicated to {len(flattened_results)} results")

        # Enrich results with open access links if available
        is_academic_query = any(result.get("source") in ["openalex", "core", "arxiv", "scholar"] for result in flattened_results)
        if is_academic_query and self.unpaywall_available:
            print("Enriching academic results with open access information")
            try:
                flattened_results = self.unpaywall_enricher.enrich_results(flattened_results)
                print("Results enriched with open access information")
            except Exception as e:
                print(f"Error enriching results with Unpaywall: {str(e)}")

        # Apply reranking if requested and available
        if use_reranker and self.reranker is not None:
            print("Using Jina Reranker for semantic ranking")
            try:
                reranked_results = self._rerank_results(flattened_results)
                print(f"Reranked {len(reranked_results)} results")
                processed_results = reranked_results
            except Exception as e:
                print(f"Error during reranking: {str(e)}. Falling back to basic scoring.")
                processed_results = self._score_and_sort_results(flattened_results)
        else:
            print("Using basic scoring")
            processed_results = self._score_and_sort_results(flattened_results)

        # Limit the number of results if requested
        if max_results is not None:
            processed_results = processed_results[:max_results]

        print(f"Processed {len(processed_results)} results {'with' if use_reranker and self.reranker is not None else 'without'} reranking")
        return processed_results

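    # A minimal sketch of the input shape process_results expects. Engine names and
    # result fields below are illustrative, not an exhaustive schema:
    #
    #   {
    #       "serper": [{"title": "...", "url": "...", "snippet": "...", "query": "..."}],
    #       "arxiv":  [{"title": "...", "url": "...", "snippet": "...", "query": "..."}],
    #   }
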
    def _flatten_results(self, search_results: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
        """
        Flatten search results from multiple search engines into a single list.

        Args:
            search_results: Dictionary mapping search engine names to lists of results

        Returns:
            Flattened list of search results
        """
        # This method is deprecated and kept for backward compatibility;
        # process_results now handles flattening directly.
        all_results = []

        # Check if we have a flattened structure (single key with all results)
        if len(search_results) == 1 and "combined" in search_results:
            return search_results["combined"]

        # Traditional structure with separate engines
        for engine, results in search_results.items():
            for result in results:
                # Add the source if not already present
                if "source" not in result:
                    result["source"] = engine
                all_results.append(result)

        return all_results

    def _deduplicate_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Deduplicate results based on URL.

        Args:
            results: List of search results

        Returns:
            Deduplicated list of search results
        """
        seen_urls = set()
        deduplicated_results = []

        for result in results:
            url = result.get("url", "")

            # Normalize URL for comparison
            normalized_url = self._normalize_url(url)

            if normalized_url and normalized_url not in seen_urls:
                seen_urls.add(normalized_url)
                deduplicated_results.append(result)

        return deduplicated_results

    def _score_and_sort_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Score and sort results by relevance.

        Args:
            results: List of search results

        Returns:
            Sorted list of search results
        """
        # Add a score to each result
        for result in results:
            score = 0

            # Boost score based on source (e.g., scholarly sources get higher scores)
            source = result.get("source", "")
            if source == "scholar":
                score += 10
            elif source == "openalex":
                score += 10  # Top priority for academic queries
            elif source == "core":
                score += 9   # High priority for open access academic content
            elif source == "arxiv":
                score += 8   # Good for preprints and specific fields
            elif source == "github":
                score += 9   # High priority for code/programming queries
            elif source.startswith("stackexchange"):
                score += 10  # Top priority for code/programming questions
            elif source == "serper":
                score += 7   # General web search
            elif source == "news":
                score += 8   # Good for current events
            elif source == "google":
                score += 5   # Generic search

            # Boost score based on position in original results
            position = result.get("raw_data", {}).get("position", 0)
            if position > 0:
                score += max(0, 10 - position)

            # Boost score for results with more content
            snippet_length = len(result.get("snippet", ""))
            if snippet_length > 200:
                score += 3
            elif snippet_length > 100:
                score += 2
            elif snippet_length > 50:
                score += 1

            # Store the score
            result["relevance_score"] = score

        # Sort by score (descending)
        sorted_results = sorted(results, key=lambda x: x.get("relevance_score", 0), reverse=True)

        return sorted_results

    def _rerank_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank results using the Jina Reranker.

        Args:
            results: List of search results

        Returns:
            Reranked list of search results
        """
        if not results:
            return []

        # Extract the original query from the first result that carries one
        query = ""
        for result in results:
            if "query" in result:
                query = result.get("query", "")
                break

        if not query:
            # If no query is found, use a fallback approach
            print("Warning: No query found in results. Using basic scoring instead.")
            return self._score_and_sort_results(results)

        print(f"Reranking with query: {query}")

        # Extract snippets for reranking
        snippets = []
        for result in results:
            # Combine title and snippet for better reranking
            content = f"{result.get('title', '')} {result.get('snippet', '')}"
            snippets.append(content)

        try:
            # Use the reranker to rerank the snippets
            reranked = self.reranker.rerank(query, snippets)

            if not reranked:
                print("Reranker returned empty results. Using basic scoring instead.")
                return self._score_and_sort_results(results)

            print(f"Reranked {len(reranked)} results")

            # Create a new list of results based on the reranking
            reranked_results = []
            for item in reranked:
                # Get the original result and add the new score
                index = item.get('index')
                score = item.get('score')

                if index is None or score is None or index >= len(results):
                    print(f"Warning: Invalid reranker result item: {item}")
                    continue

                original_result = results[index]
                new_result = original_result.copy()
                new_result['relevance_score'] = float(score) * 10  # Scale up the score for consistency
                reranked_results.append(new_result)

            # If we didn't get any valid results, fall back to basic scoring
            if not reranked_results:
                print("No valid reranked results. Using basic scoring instead.")
                return self._score_and_sort_results(results)

            return reranked_results
        except Exception as e:
            print(f"Error reranking results: {str(e)}")
            # Fall back to basic scoring if reranking fails
            return self._score_and_sort_results(results)

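    # Assumption (inferred from how _rerank_results reads the items): the reranker is
    # expected to return entries shaped roughly like {"index": 3, "score": 0.87}, i.e.
    # the position of the snippet in the input list plus a relevance score; entries
    # missing either key, or with an out-of-range index, are skipped.
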
    def _extract_domain(self, url: str) -> str:
        """
        Extract the domain from a URL.

        Args:
            url: URL to extract domain from

        Returns:
            Domain name
        """
        try:
            parsed_url = urlparse(url)
            domain = parsed_url.netloc

            # Remove 'www.' prefix if present
            if domain.startswith('www.'):
                domain = domain[4:]

            return domain
        except Exception:
            return ""

    def _normalize_url(self, url: str) -> str:
        """
        Normalize a URL for comparison.

        Args:
            url: URL to normalize

        Returns:
            Normalized URL
        """
        try:
            # Parse the URL
            parsed_url = urlparse(url)

            # Reconstruct with just the scheme, netloc, and path
            normalized = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"

            # Remove trailing slash if present
            if normalized.endswith('/'):
                normalized = normalized[:-1]

            return normalized.lower()
        except Exception:
            return url.lower()

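    # For example (illustrative), "https://www.Example.com/Path/?utm_source=x"
    # normalizes to "https://www.example.com/path": query strings, fragments,
    # trailing slashes, and case differences are ignored when deduplicating.
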
    def filter_results(self,
                       results: List[Dict[str, Any]],
                       filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Filter results based on specified criteria.

        Args:
            results: List of search results
            filters: Dictionary of filter criteria:
                - domains: List of domains to include or exclude
                - exclude_domains: Whether to exclude (True) or include (False) the specified domains
                - min_score: Minimum relevance score
                - sources: List of sources to include
                - date_range: Dictionary with 'start' and 'end' dates

        Returns:
            Filtered list of search results
        """
        filtered_results = results.copy()

        # Filter by domains
        if "domains" in filters and filters["domains"]:
            domains = set(filters["domains"])
            exclude_domains = filters.get("exclude_domains", False)

            if exclude_domains:
                filtered_results = [r for r in filtered_results if r.get("domain", "") not in domains]
            else:
                filtered_results = [r for r in filtered_results if r.get("domain", "") in domains]

        # Filter by minimum score
        if "min_score" in filters:
            min_score = filters["min_score"]
            filtered_results = [r for r in filtered_results if r.get("relevance_score", 0) >= min_score]

        # Filter by sources
        if "sources" in filters and filters["sources"]:
            sources = set(filters["sources"])
            filtered_results = [r for r in filtered_results if r.get("source", "") in sources]

        # Filter by date range
        if "date_range" in filters:
            date_range = filters["date_range"]

            if "start" in date_range:
                start_date = datetime.fromisoformat(date_range["start"])
                filtered_results = [
                    r for r in filtered_results
                    if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) >= start_date
                ]

            if "end" in date_range:
                end_date = datetime.fromisoformat(date_range["end"])
                filtered_results = [
                    r for r in filtered_results
                    if "date" not in r or not r["date"] or datetime.fromisoformat(r["date"]) <= end_date
                ]

        return filtered_results

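    # Example filter criteria for filter_results (a hypothetical sketch; the values
    # are illustrative and dates must be ISO 8601 strings):
    #
    #   filters = {
    #       "domains": ["arxiv.org", "github.com"],
    #       "exclude_domains": False,
    #       "min_score": 5,
    #       "sources": ["arxiv", "github"],
    #       "date_range": {"start": "2023-01-01", "end": "2024-06-30"},
    #   }
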
    def group_results_by_domain(self, results: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Group results by domain.

        Args:
            results: List of search results

        Returns:
            Dictionary mapping domains to lists of search results
        """
        grouped_results = {}

        for result in results:
            domain = result.get("domain", "unknown")

            if domain not in grouped_results:
                grouped_results[domain] = []

            grouped_results[domain].append(result)

        return grouped_results

    def save_results(self, results: List[Dict[str, Any]], file_path: str) -> None:
        """
        Save search results to a file.

        Args:
            results: List of search results
            file_path: Path to save results to
        """
        try:
            with open(file_path, 'w') as f:
                json.dump(results, f, indent=2)
            print(f"Results saved to {file_path}")
        except Exception as e:
            print(f"Error saving results: {e}")

    def load_results(self, file_path: str) -> List[Dict[str, Any]]:
        """
        Load search results from a file.

        Args:
            file_path: Path to load results from

        Returns:
            List of search results
        """
        try:
            with open(file_path, 'r') as f:
                results = json.load(f)
            return results
        except Exception as e:
            print(f"Error loading results: {e}")
            return []
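

# A minimal usage sketch, assuming the module's own imports (ranking.jina_reranker)
# resolve in your environment; if the Jina reranker or Unpaywall enricher is
# unavailable, the collector falls back to basic scoring. The sample results below
# are hypothetical, not real search output.
if __name__ == "__main__":
    sample_results = {
        "serper": [
            {"title": "Result A", "url": "https://example.com/a",
             "snippet": "First sample snippet.", "query": "example query"},
        ],
        "arxiv": [
            {"title": "Result B", "url": "https://example.com/a/",
             "snippet": "Same URL after normalization, so deduplicated.", "query": "example query"},
        ],
    }

    collector = ResultCollector()
    processed = collector.process_results(sample_results, dedup=True, max_results=5)
    for r in processed:
        print(r.get("source"), r.get("url"), r.get("relevance_score"))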