190 lines
7.3 KiB
Python
190 lines
7.3 KiB
Python
"""
|
|
Jina AI Reranker module for the intelligent research system.
|
|
|
|
This module provides functionality to rerank documents based on their relevance
|
|
to a query using Jina AI's Reranker API.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import requests
|
|
from typing import List, Dict, Any, Optional, Union
|
|
|
|
from config.config import get_config
|
|
|
|
|
|
class JinaReranker:
    """
    Document reranker using Jina AI's Reranker API.

    This class provides methods to rerank documents based on their relevance
    to a query, improving the quality of search results. Failures while
    talking to the API (network errors, timeouts, non-200 responses,
    undecodable JSON) degrade to the documents' original order instead of
    raising.
    """

    def __init__(self):
        """Initialize the Jina Reranker from the application config.

        Reads the Jina API key and the optional ``jina.reranker`` section of
        the loaded configuration (keys ``model``, ``top_n`` and ``timeout``).

        Raises:
            ValueError: If the Jina AI API key cannot be found.
        """
        self.config = get_config()
        self.api_key = self._get_api_key()
        self.endpoint = "https://api.jina.ai/v1/rerank"

        # Get reranker configuration
        self.reranker_config = self.config.config_data.get('jina', {}).get('reranker', {})
        self.model = self.reranker_config.get('model', 'jina-reranker-v2-base-multilingual')
        self.default_top_n = self.reranker_config.get('top_n', 10)
        # Seconds passed to requests as the HTTP timeout. Without one, a
        # stalled connection would block the caller indefinitely.
        self.timeout = self.reranker_config.get('timeout', 30)

    def _get_api_key(self) -> str:
        """
        Get the Jina AI API key.

        Returns:
            The API key as a string

        Raises:
            ValueError: If the API key is not found
        """
        try:
            return self.config.get_api_key('jina')
        except ValueError as e:
            # Chain the original error so the underlying cause stays visible.
            raise ValueError(f"Jina AI API key not found. {str(e)}") from e

    @staticmethod
    def _fallback_results(documents: List[str], top_n: int) -> List[Dict[str, Any]]:
        """Return the first ``top_n`` documents in original order, each with score 1.0."""
        return [{'index': i, 'score': 1.0, 'document': doc}
                for i, doc in enumerate(documents[:top_n])]

    @staticmethod
    def _parse_results(result: Any, documents: List[str]) -> List[Dict[str, Any]]:
        """Convert a decoded API payload into ``{'index', 'score', 'document'}`` dicts.

        Items are read from ``result["results"]`` when that is a list,
        otherwise from ``result["data"]`` (older response shape). Items that
        are not dicts, lack ``relevance_score``, or carry an index that does
        not refer to a submitted document are skipped. The returned
        ``'document'`` is always the text that was submitted, not any echo
        the API may include.
        """
        if not isinstance(result, dict):
            return []

        items = result.get("results")
        if not isinstance(items, list):
            # Fallback for older response structures with "data" field
            items = result.get("data")
        if not isinstance(items, list):
            return []

        parsed = []
        for item in items:
            if not isinstance(item, dict) or "relevance_score" not in item:
                continue
            idx = item.get("index")
            # Reject non-integer and out-of-range indices. Negative indices
            # would otherwise silently wrap around to the wrong document.
            if isinstance(idx, bool) or not isinstance(idx, int):
                continue
            if not 0 <= idx < len(documents):
                continue
            parsed.append({
                'index': idx,
                'score': item['relevance_score'],
                'document': documents[idx],
            })
        return parsed

    def rerank(self, query: str, documents: List[str],
               top_n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents based on their relevance to the query.

        Args:
            query: The query to rank documents against
            documents: List of document strings to rerank
            top_n: Number of top results to return (optional). Defaults to
                the configured ``top_n``; always capped at ``len(documents)``.

        Returns:
            List of dictionaries with keys ``'index'`` (position in
            ``documents``), ``'score'`` (relevance score) and ``'document'``
            (the submitted text), in the order returned by the API.
            An empty list is returned when ``documents`` is empty or
            ``top_n`` is not positive. If the request fails, times out,
            returns a non-200 status, or its body is not valid JSON, the
            first ``top_n`` documents are returned in their original order
            with a score of 1.0 rather than raising.
        """
        if not documents:
            return []

        # Use the configured default when top_n is not given, and never ask
        # for more results than there are documents.
        requested = self.default_top_n if top_n is None else top_n
        top_n = min(requested, len(documents))
        if top_n <= 0:
            # Nothing was requested, so skip the API call entirely.
            return []

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json"
        }

        # The correct format is an array of plain strings, not objects with a "text" field
        data = {
            "model": self.model,
            "query": query,
            "documents": documents,  # Plain array of strings
            "top_n": top_n
        }

        print(f"Making reranker API call with query: {query[:50]}... and {len(documents)} documents")
        print(f"Request payload structure: model, query, documents (array of {len(documents)} strings), top_n={top_n}")

        try:
            response = requests.post(self.endpoint, headers=headers, json=data,
                                     timeout=self.timeout)
            print(f"Reranker API response status: {response.status_code}")

            if response.status_code != 200:
                # Handle API-side errors the same way as transport errors:
                # keep the original ordering instead of dropping every document.
                print(f"Reranker API error: {response.text}")
                return self._fallback_results(documents, top_n)

            result = response.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException covers connection errors and timeouts.
            # ValueError covers a response body that is not valid JSON.
            print(f"Error calling Jina Reranker API: {str(e)}")
            # Return original documents with default ordering in case of error
            return self._fallback_results(documents, top_n)

        if isinstance(result, dict):
            print(f"Reranker API response structure: {list(result.keys())}")

        reranked_results = self._parse_results(result, documents)
        print(f"Processed reranker results: {len(reranked_results)} items")
        return reranked_results

    def rerank_with_metadata(self, query: str, documents: List[Dict[str, Any]],
                             document_key: str = 'content',
                             top_n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Rerank documents with metadata based on their relevance to the query.

        Args:
            query: The query to rank documents against
            documents: List of document dictionaries containing content and metadata
            document_key: The key in the document dictionaries that contains the text content
            top_n: Number of top results to return (optional)

        Returns:
            The result list of ``rerank`` (``'index'``, ``'score'``,
            ``'document'``), with each entry's ``'metadata'`` set to the
            original dictionary at that index. Failures fall back to the
            original order exactly as ``rerank`` does.
        """
        if not documents:
            return []

        # Extract document contents. A missing or None field is sent as an
        # empty string rather than JSON null.
        doc_contents = []
        for doc in documents:
            content = doc.get(document_key)
            doc_contents.append("" if content is None else content)

        # Rerank the document contents
        reranked_results = self.rerank(query, doc_contents, top_n)

        # Add original metadata to the results. rerank only returns indices
        # within range of doc_contents, so this lookup is safe.
        for result in reranked_results:
            result['metadata'] = documents[result['index']]

        return reranked_results
|
|
|
|
|
|
# Shared module-level reranker, constructed once when this module is imported.
# Construction reads the config and API key immediately, so if the key
# lookup raises, importing this module fails.
jina_reranker: JinaReranker = JinaReranker()
|
|
|
|
|
|
def get_jina_reranker() -> JinaReranker:
    """Return the shared, module-level reranker.

    The instance is created once at import time; every call hands back that
    same object, so callers share its configuration and API key.

    Returns:
        The global JinaReranker instance.
    """
    shared_instance = jina_reranker
    return shared_instance
|