ira/jina_similarity.py

"""
A module for computing text similarity using Jina AI's Embeddings API.
Get your Jina AI API key for free: https://jina.ai/?sui=apikey
The jina-embeddings-v3 model supports input lengths of up to 8,192 tokens.
For longer texts, consider using Jina's Segmenter API to split into smaller chunks.
"""
import os
import requests
import numpy as np
import tiktoken
from typing import Tuple


class TokenLimitError(Exception):
    """Raised when input text exceeds the token limit."""
    pass


class JinaSimilarity:
    MAX_TOKENS = 8192

    def __init__(self):
        """Initialize the JinaSimilarity class."""
        self.api_key = os.environ.get("JINA_API_KEY")
        if not self.api_key:
            raise ValueError("JINA_API_KEY environment variable not set")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }
        self.embeddings_url = "https://api.jina.ai/v1/embeddings"

        # Initialize tokenizer - using cl100k_base, which is used by many modern models
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text.

        Args:
            text: The text to count tokens for

        Returns:
            int: Number of tokens in the text
        """
        return len(self.tokenizer.encode(text))

    def get_embedding(self, text: str) -> list:
        """Get the embedding for a piece of text using Jina AI's Embeddings API.

        Args:
            text: The text to get embeddings for (max 8,192 tokens)

        Returns:
            list: The embedding vector

        Raises:
            TokenLimitError: If the text exceeds 8,192 tokens
            requests.exceptions.RequestException: If the API call fails
        """
        num_tokens = self.count_tokens(text)
        if num_tokens > self.MAX_TOKENS:
            raise TokenLimitError(
                f"Input text is {num_tokens} tokens, which exceeds the maximum of "
                f"{self.MAX_TOKENS} tokens. Consider using Jina's Segmenter API to "
                "split it into smaller chunks."
            )

        payload = {
            "model": "jina-embeddings-v3",
            "input": [text],
            "normalized": True  # For cosine similarity
        }

        response = requests.post(
            self.embeddings_url,
            headers=self.headers,
            json=payload
        )
        response.raise_for_status()

        return response.json()["data"][0]["embedding"]

    def compute_similarity(self, chunk: str, query: str) -> Tuple[float, list, list]:
        """Compute similarity between a text chunk and a query.

        Args:
            chunk: The text chunk to compare against
            query: The query text

        Returns:
            Tuple containing:
                - float: Cosine similarity score (0-1)
                - list: Chunk embedding
                - list: Query embedding

        Raises:
            TokenLimitError: If either text exceeds 8,192 tokens
            requests.exceptions.RequestException: If the API calls fail
        """
        # Get embeddings for both texts
        chunk_embedding = self.get_embedding(chunk)
        query_embedding = self.get_embedding(query)

        # Convert to numpy arrays for efficient computation
        chunk_vec = np.array(chunk_embedding)
        query_vec = np.array(query_embedding)

        # Compute cosine similarity
        # Since the vectors are normalized, the dot product equals cosine similarity
        similarity = float(np.dot(chunk_vec, query_vec))

        return similarity, chunk_embedding, query_embedding
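

# Illustrative usage sketch (not part of the original module): it assumes
# JINA_API_KEY is set in the environment and that api.jina.ai is reachable.
# The example texts are placeholders, chosen only to demonstrate the call flow.
if __name__ == "__main__":
    similarity_checker = JinaSimilarity()
    chunk = "Jina AI provides an Embeddings API for semantic search."
    query = "How do I compute text similarity with Jina embeddings?"
    score, _, _ = similarity_checker.compute_similarity(chunk, query)
    print(f"Cosine similarity: {score:.4f}")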