"""
|
|
A module for computing text similarity using Jina AI's Embeddings API.
|
|
Get your Jina AI API key for free: https://jina.ai/?sui=apikey
|
|
|
|
The jina-embeddings-v3 model supports input lengths of up to 8,192 tokens.
|
|
For longer texts, consider using Jina's Segmenter API to split into smaller chunks.
|
|
"""
import os
from typing import Tuple

import numpy as np
import requests
import tiktoken

class TokenLimitError(Exception):
    """Raised when the input text exceeds the model's token limit."""
class JinaSimilarity:
    """Compute cosine similarity between texts via Jina AI's Embeddings API."""

    # jina-embeddings-v3 accepts at most 8,192 tokens per input.
    MAX_TOKENS = 8192
    # Seconds to wait for the embeddings API before giving up. Without a
    # timeout, requests.post() can block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 60

    def __init__(self):
        """Initialize the JinaSimilarity client.

        Reads the API key from the JINA_API_KEY environment variable and
        prepares request headers and the local tokenizer.

        Raises:
            ValueError: If the JINA_API_KEY environment variable is not set.
        """
        self.api_key = os.environ.get("JINA_API_KEY")
        if not self.api_key:
            raise ValueError("JINA_API_KEY environment variable not set")

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "application/json",
            "Content-Type": "application/json"
        }
        self.embeddings_url = "https://api.jina.ai/v1/embeddings"
        # cl100k_base is used by many modern models but only approximates
        # Jina's own tokenizer; counts may drift slightly from the server's.
        # NOTE(review): confirm the drift is acceptable near MAX_TOKENS.
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text.

        Args:
            text: The text to count tokens for

        Returns:
            int: Number of tokens in the text
        """
        return len(self.tokenizer.encode(text))

    def get_embedding(self, text: str) -> list:
        """Get embedding for a piece of text using Jina AI's Embeddings API.

        Args:
            text: The text to get embeddings for (max 8,192 tokens)

        Returns:
            list: The embedding vector (unit-normalized by the API, so dot
                products are cosine similarities)

        Raises:
            TokenLimitError: If the text exceeds 8,192 tokens
            requests.exceptions.RequestException: If the API call fails
                or times out
        """
        num_tokens = self.count_tokens(text)
        if num_tokens > self.MAX_TOKENS:
            raise TokenLimitError(
                f"Input text is {num_tokens} tokens, which exceeds the maximum of {self.MAX_TOKENS} tokens. "
                "Consider using Jina's Segmenter API to split into smaller chunks."
            )

        payload = {
            "model": "jina-embeddings-v3",
            "input": [text],
            "normalized": True  # For cosine similarity
        }

        response = requests.post(
            self.embeddings_url,
            headers=self.headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,  # fail fast instead of hanging forever
        )
        response.raise_for_status()

        return response.json()["data"][0]["embedding"]

    def compute_similarity(self, chunk: str, query: str) -> Tuple[float, list, list]:
        """Compute similarity between a text chunk and a query.

        Args:
            chunk: The text chunk to compare against
            query: The query text

        Returns:
            Tuple containing:
                - float: Cosine similarity score (0-1)
                - list: Chunk embedding
                - list: Query embedding

        Raises:
            TokenLimitError: If either text exceeds 8,192 tokens
            requests.exceptions.RequestException: If the API calls fail
        """
        # Get embeddings for both texts (two separate API calls).
        chunk_embedding = self.get_embedding(chunk)
        query_embedding = self.get_embedding(query)

        # Convert to numpy arrays for efficient computation.
        chunk_vec = np.array(chunk_embedding)
        query_vec = np.array(query_embedding)

        # Since the API returns normalized vectors (normalized=True above),
        # the dot product equals cosine similarity.
        similarity = float(np.dot(chunk_vec, query_vec))

        return similarity, chunk_embedding, query_embedding