# ira/tests/ranking/test_simple_reranker.py
#
# 153 lines
# 6.0 KiB
# Python

import json
import sys
import os
import yaml
from pathlib import Path
# Make modules that sit next to this test importable when it is run as a
# standalone script. This adds the directory containing this file -- NOT the
# project root.
# NOTE(review): if project packages need to be imported, an ancestor directory
# (the project root) would have to be added instead -- confirm the intent.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:  # avoid duplicate entries on re-execution
    sys.path.append(_THIS_DIR)
class TestJinaReranker:
    """Minimal Jina reranker client with explicit initialization parameters.

    Built specifically for this test script, so the API key, model and
    endpoint are passed in directly rather than read from configuration.
    """

    # Despite the ``Test`` prefix this is a helper, not a test case. Tell
    # pytest not to collect it (it would otherwise warn that it cannot
    # collect a class with an ``__init__`` constructor).
    __test__ = False

    def __init__(self, api_key, model, endpoint, timeout=30):
        """Store connection settings.

        Args:
            api_key: API key sent as a ``Bearer`` token.
            model: Reranker model name placed in the request payload.
            endpoint: URL the rerank request is POSTed to.
            timeout: Seconds to wait for the HTTP request. ``requests``
                waits indefinitely when no timeout is given.
        """
        self.api_key = api_key
        self.model = model
        self.endpoint = endpoint
        self.timeout = timeout
        self.default_top_n = 10

    def rerank(self, query, documents, top_n=None):
        """Rerank ``documents`` by relevance to ``query`` via the rerank API.

        Args:
            query: The query string.
            documents: List of document strings, sent as a plain array.
            top_n: Maximum number of results to request. Defaults to
                ``self.default_top_n``; always capped at ``len(documents)``.

        Returns:
            A list of ``{'index', 'score', 'document'}`` dicts in the order
            the API returned them. Returns ``[]`` when there are no
            documents, or when the request fails, returns a non-200 status,
            or returns an unparseable body.
        """
        if not documents:
            return []

        if top_n is None:
            top_n = self.default_top_n
        top_n = min(top_n, len(documents))

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
        }
        data = {
            "model": self.model,
            "query": query,
            "documents": documents,  # Plain array of strings
            "top_n": top_n,
        }
        print(f"Making reranker API call with query: {query}")
        print(f"Request payload structure: model, query, documents (array of {len(documents)} strings), top_n={top_n}")

        import requests  # local import: only needed once a request is made

        try:
            response = requests.post(
                self.endpoint, headers=headers, json=data, timeout=self.timeout
            )
        except requests.RequestException as e:
            # Covers connection errors, timeouts and invalid URLs.
            print(f"Error calling reranker API: {str(e)}")
            return []

        print(f"Reranker API response status: {response.status_code}")
        if response.status_code != 200:
            print(f"Reranker API error: {response.text}")
            return []

        try:
            result = response.json()
        except ValueError as e:  # body is not valid JSON
            print(f"Error calling reranker API: {str(e)}")
            return []

        if not isinstance(result, dict):
            print(f"Unexpected reranker API response: {result!r}")
            return []

        print(f"Reranker API response structure: {list(result.keys())}")
        print(f"Full response: {json.dumps(result, indent=2)}")

        reranked_results = self._parse_results(result, documents)
        print(f"Processed reranker results: {len(reranked_results)} items")
        return reranked_results

    @staticmethod
    def _parse_results(result, documents):
        """Convert a decoded rerank response into ``{'index', 'score', 'document'}`` dicts.

        Items are read from ``result["results"]`` when it is a list, else
        from the older ``result["data"]`` list. Each item needs ``index`` and
        ``relevance_score`` keys; other items are skipped. An item may also
        carry its own ``document`` field (the newer response format). That
        field is ignored and the text is taken from the local ``documents``
        list instead. ``document`` is ``None`` when ``index`` is not a valid,
        non-negative position in ``documents``.
        """
        if isinstance(result.get("results"), list):
            items = result["results"]
        elif isinstance(result.get("data"), list):  # fallback for older response structures
            items = result["data"]
        else:
            return []

        reranked = []
        for item in items:
            if not (isinstance(item, dict) and "index" in item and "relevance_score" in item):
                continue
            index = item["index"]
            # Reject negative indices explicitly: they would silently wrap
            # around to the wrong document.
            in_range = isinstance(index, int) and 0 <= index < len(documents)
            reranked.append({
                'index': index,
                'score': item['relevance_score'],
                'document': documents[index] if in_range else None,
            })
        return reranked
def load_config(config_path=None):
    """Load configuration from a YAML file.

    Args:
        config_path: Path to the YAML file. Defaults to
            ``config/config.yaml`` inside the directory containing this file.

    Returns:
        The parsed YAML document. Returns ``{}`` when the file does not
        exist or is empty.
    """
    if config_path is None:
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config", "config.yaml")
    print(f"Loading config from {config_path}")
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Config file not found at {config_path}")
        return {}
    print("Configuration loaded successfully")
    # safe_load returns None for an empty document; return {} so an empty
    # file behaves like a missing one.
    return config if config is not None else {}
def test_simple_reranker():
    """Live smoke test: rerank a few documents with the Jina Reranker API.

    Needs the ``JINA_API_KEY`` environment variable and network access.
    Without the key the test raises ``unittest.SkipTest``, which pytest
    reports as a skip, instead of silently passing.
    """
    import unittest  # local import; only needed for the skip signal

    jina_api_key = os.environ.get("JINA_API_KEY", "")
    if not jina_api_key:
        print("JINA_API_KEY not found in environment variables")
        raise unittest.SkipTest("JINA_API_KEY not found in environment variables")
    print("Found JINA_API_KEY in environment variables")

    reranker = TestJinaReranker(
        api_key=jina_api_key,
        model="jina-reranker-v2-base-multilingual",
        endpoint="https://api.jina.ai/v1/rerank",
    )

    query = "What is quantum computing?"
    documents = [
        "Quantum computing is a type of computation that harnesses quantum mechanics.",
        "Classical computers use bits, while quantum computers use qubits.",
        "Machine learning is a subset of artificial intelligence.",
        "Quantum computers can solve certain problems faster than classical computers.",
    ]
    print(f"Testing simple reranker with query: {query}")
    print(f"Documents: {documents}")

    reranked = reranker.rerank(query, documents)
    print(f"Reranked results: {json.dumps(reranked, indent=2)}")

    # rerank() returns [] on HTTP errors and request failures, so an empty
    # list means the API call did not succeed.
    assert reranked, "Reranker returned no results (see API error output above)"
    assert len(reranked) <= len(documents)
    for item in reranked:
        # Every result must map back to one of the submitted documents.
        assert item["document"] in documents
if __name__ == "__main__":
    import unittest

    # Run the smoke test directly. A skip signalled with unittest.SkipTest is
    # reported as a short message instead of a traceback, as pytest would.
    try:
        test_simple_reranker()
    except unittest.SkipTest as skip:
        print(f"Skipped: {skip}")