ira/tests/ranking/test_similarity.py

100 lines
2.7 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Test script for the JinaSimilarity module.
Computes similarity between text from two input files.
"""
import argparse
import sys
from pathlib import Path
from jina_similarity import JinaSimilarity, TokenLimitError
def read_file(file_path: str) -> str:
    """Read a UTF-8 text file and return its content, whitespace-trimmed.

    Args:
        file_path: Path to the file to read.

    Returns:
        The file's content decoded as UTF-8, with leading and trailing
        whitespace (including the final newline) removed.

    Raises:
        FileNotFoundError: If the file doesn't exist.
        UnicodeDecodeError: If the file's bytes are not valid UTF-8.
    """
    # The encoding is pinned so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read().strip()
def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: two positional file paths plus --verbose/-v."""
    parser = argparse.ArgumentParser(
        description='Compute similarity between text from two files using Jina AI.'
    )
    parser.add_argument(
        'chunk_file',
        type=str,
        help='Path to the file containing the text chunk'
    )
    parser.add_argument(
        'query_file',
        type=str,
        help='Path to the file containing the query'
    )
    parser.add_argument(
        '--verbose',
        '-v',
        action='store_true',
        help='Print token counts and embeddings'
    )
    return parser


def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: print the similarity of two text files.

    Parses the arguments, checks that both paths are existing regular
    files, reads them with ``read_file`` and passes the two texts to
    ``JinaSimilarity.compute_similarity``. The similarity score is printed
    with four decimals. With ``--verbose`` the token counts (from
    ``JinaSimilarity.count_tokens``) and the first five elements of each
    embedding are printed as well.

    Args:
        argv: Argument list to parse instead of the process arguments.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            so calling ``main()`` behaves as before.

    Raises:
        SystemExit: With status 1 when an input file is missing or when
            reading or the similarity computation raises; with argparse's
            own status on invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # Fail early with a clear message before any model work is attempted.
    for label, file_arg in (("Chunk", args.chunk_file), ("Query", args.query_file)):
        if not Path(file_arg).is_file():
            print(f"Error: {label} file not found: {file_arg}", file=sys.stderr)
            sys.exit(1)

    try:
        # Read input files
        chunk_text = read_file(args.chunk_file)
        query_text = read_file(args.query_file)

        # Initialize similarity module
        js = JinaSimilarity()

        # Get token counts if verbose
        if args.verbose:
            chunk_tokens = js.count_tokens(chunk_text)
            query_tokens = js.count_tokens(query_text)
            print("\nToken counts:")
            print(f"Chunk: {chunk_tokens} tokens")
            print(f"Query: {query_tokens} tokens\n")

        # Compute similarity; the call also returns both embeddings
        similarity, chunk_embedding, query_embedding = js.compute_similarity(
            chunk_text,
            query_text
        )

        # Print results
        print(f"Similarity score: {similarity:.4f}")
        if args.verbose:
            print("\nEmbeddings:")
            print(f"Chunk embedding (first 5): {chunk_embedding[:5]}...")
            print(f"Query embedding (first 5): {query_embedding[:5]}...")
    except TokenLimitError as e:
        # Kept as its own clause so this anticipated failure is explicit,
        # even though it is reported the same way as any other error.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Top-level CLI boundary: report any failure as a one-line message
        # and a non-zero exit status instead of a traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()