"""
|
|
LLM interface module using LiteLLM.
|
|
|
|
This module provides a unified interface to various LLM providers through LiteLLM,
|
|
enabling query enhancement, classification, and other LLM-powered functionality.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
import asyncio
|
|
|
|
import litellm
|
|
from litellm import completion
|
|
|
|
from config.config import get_config
|
|
|
|
|
|
class LLMInterface:
    """Interface for interacting with LLMs through LiteLLM."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the LLM model to use. If None, uses the default model
                from configuration.
        """
        self.config = get_config()

        # Use specified model or default from config
        self.model_name = model_name or self.config.config_data.get('default_model', 'gpt-3.5-turbo')

        # Get model-specific configuration
        self.model_config = self.config.get_model_config(self.model_name)

        # Set up LiteLLM with the appropriate provider
        self._setup_provider()

    def _setup_provider(self) -> None:
        """Set up the LLM provider based on the model configuration."""
        provider = self.model_config.get('provider', 'openai')

        try:
            # Get API key for the provider
            api_key = self.config.get_api_key(provider)

            # Set environment variable for the provider
            if provider.lower() == 'google' or provider.lower() == 'gemini':
                os.environ["GEMINI_API_KEY"] = api_key
            elif provider.lower() == 'vertex_ai':
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
            else:
                os.environ[f"{provider.upper()}_API_KEY"] = api_key

            print(f"LLM interface initialized with model: {self.model_name} (provider: {provider})")
        except ValueError as e:
            print(f"Error setting up LLM provider: {e}")

    def _get_completion_params(self) -> Dict[str, Any]:
        """
        Get parameters for LLM completion based on model configuration.

        Returns:
            Dictionary of parameters for LiteLLM completion
        """
        params = {
            'temperature': self.model_config.get('temperature', 0.7),
            'max_tokens': self.model_config.get('max_tokens', 1000),
            'top_p': self.model_config.get('top_p', 1.0)
        }

        # Handle different provider configurations
        provider = self.model_config.get('provider', 'openai')

        if provider == 'azure':
            # Azure OpenAI requires special handling
            deployment_name = self.model_config.get('deployment_name')
            api_version = self.model_config.get('api_version')
            endpoint = self.model_config.get('endpoint')

            if deployment_name and endpoint:
                # Format: azure/deployment_name
                params['model'] = f"azure/{deployment_name}"

                # Set Azure-specific environment variables if not already set
                if 'AZURE_API_BASE' not in os.environ and endpoint:
                    os.environ['AZURE_API_BASE'] = endpoint

                if 'AZURE_API_VERSION' not in os.environ and api_version:
                    os.environ['AZURE_API_VERSION'] = api_version
            else:
                # Fall back to default model if Azure config is incomplete
                params['model'] = self.model_name
        elif provider in ['ollama', 'groq', 'openrouter'] or self.model_config.get('endpoint'):
            # For providers with custom endpoints
            params['model'] = self.model_config.get('model_name', self.model_name)
            params['api_base'] = self.model_config.get('endpoint')

            # Special handling for OpenRouter
            if provider == 'openrouter':
                # Set HTTP headers for OpenRouter if needed
                params['headers'] = {
                    'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                    'X-Title': 'Intelligent Research System'  # Replace with your actual app name
                }
        elif provider == 'google' or provider == 'gemini':
            # Special handling for Google Gemini models
            # Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
            params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"

            # Add additional parameters for Gemini
            params['custom_llm_provider'] = 'gemini'
        elif provider == 'vertex_ai':
            # Special handling for Vertex AI Gemini models
            params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"

            # Add Vertex AI specific parameters
            params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
            params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')

            # Set custom provider
            params['custom_llm_provider'] = 'vertex_ai'
        else:
            # Standard provider (OpenAI, Anthropic, etc.)
            params['model'] = self.model_name

        return params

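    # Illustrative model configuration (an assumed shape, not the authoritative schema;
    # the real schema is whatever get_config().get_model_config() returns). The method
    # above reads keys such as the ones shown here:
    #
    #   gpt-3.5-turbo:
    #     provider: openai
    #     temperature: 0.7
    #     max_tokens: 1000
    #     top_p: 1.0
    #   azure-gpt:
    #     provider: azure
    #     deployment_name: <your-deployment>
    #     api_version: <api-version>
    #     endpoint: https://<your-resource>.openai.azure.com
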
    async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
        """
        Generate a completion using the configured LLM.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            stream: Whether to stream the response

        Returns:
            If stream is False, returns the completion text as a string.
            If stream is True, returns the completion response object for streaming.
        """
        # Get provider from model config
        provider = self.model_config.get('provider', 'openai').lower()

        # Special handling for Gemini models - they use 'user' and 'model' roles
        if provider == 'gemini':
            formatted_messages = []
            for msg in messages:
                role = msg['role']
                # Map 'system' to 'user' for the first message
                if role == 'system' and not formatted_messages:
                    formatted_messages.append({
                        'role': 'user',
                        'content': msg['content']
                    })
                # Map 'assistant' to 'model'
                elif role == 'assistant':
                    formatted_messages.append({
                        'role': 'model',
                        'content': msg['content']
                    })
                # Keep 'user' as is
                else:
                    formatted_messages.append(msg)
        else:
            formatted_messages = messages

        # Get completion parameters
        params = self._get_completion_params()

        try:
            # Generate completion
            if stream:
                response = litellm.completion(
                    messages=formatted_messages,
                    stream=True,
                    **params
                )
                return response
            else:
                response = litellm.completion(
                    messages=formatted_messages,
                    **params
                )

                # Extract content from response
                content = response.choices[0].message.content

                # Process thinking tags if enabled
                if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
                    content = self._process_thinking_tags(content)

                return content
        except Exception as e:
            error_msg = f"Error generating completion: {str(e)}"
            print(error_msg)

            # Return error message in a user-friendly format
            return f"I encountered an error while processing your request: {str(e)}"

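    # Streaming usage sketch (illustrative, where llm is an LLMInterface instance): with
    # stream=True the raw LiteLLM response is returned, and chunks follow the OpenAI-style
    # delta format, e.g.
    #
    #   response = await llm.generate_completion(messages, stream=True)
    #   for chunk in response:
    #       delta = chunk.choices[0].delta.content
    #       if delta:
    #           print(delta, end="", flush=True)
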
    async def classify_query(self, query: str) -> Dict[str, str]:
        """
        Classify a query as factual, exploratory, or comparative.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query type and confidence
        """
        # Call the async implementation directly
        return await self._classify_query_impl(query)

    async def _classify_query_impl(self, query: str) -> Dict[str, str]:
        """
        Classify a query as factual, exploratory, or comparative.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query type and confidence
        """
        messages = [
            {"role": "system", "content": """You are an expert query classifier.
            Analyze the given query and classify it into one of the following types:
            - factual: Seeking specific facts or information
            - exploratory: Seeking to understand a topic broadly
            - comparative: Seeking to compare multiple items or concepts

            Respond with a JSON object containing:
            - type: The query type (factual, exploratory, or comparative)
            - confidence: Your confidence in this classification (high, medium, low)

            Example response:
            {"type": "exploratory", "confidence": "high"}
            """},
            {"role": "user", "content": query}
        ]

        # Generate classification
        response = await self.generate_completion(messages)

        # Parse JSON response
        try:
            classification = json.loads(response)
            return classification
        except json.JSONDecodeError:
            # Fallback to default classification if parsing fails
            print(f"Error parsing classification response: {response}")
            return {"type": "exploratory", "confidence": "low"}

    async def enhance_query(self, query: str) -> str:
        """
        Enhance a user query using the LLM.

        Args:
            query: The raw user query

        Returns:
            Enhanced query with additional context and structure
        """
        # Get the model assigned to this specific function
        model_name = self.config.get_module_model('query_processing', 'enhance_query')

        # Create a new interface with the assigned model if different from current
        if model_name != self.model_name:
            interface = LLMInterface(model_name)
            return await interface._enhance_query_impl(query)

        return await self._enhance_query_impl(query)

    async def _enhance_query_impl(self, query: str) -> str:
        """Implementation of query enhancement."""
        messages = [
            {"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."},
            {"role": "user", "content": f"Enhance this research query: {query}"}
        ]

        return await self.generate_completion(messages)

    async def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """
        Generate optimized search queries for different search engines.

        Args:
            query: The original user query
            search_engines: List of search engines to generate queries for

        Returns:
            Dictionary mapping search engines to lists of optimized queries
        """
        # Get the model assigned to this specific function
        model_name = self.config.get_module_model('query_processing', 'generate_search_queries')

        # Create a new interface with the assigned model if different from current
        if model_name != self.model_name:
            interface = LLMInterface(model_name)
            return await interface._generate_search_queries_impl(query, search_engines)

        return await self._generate_search_queries_impl(query, search_engines)

    async def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """Implementation of search query generation."""
        engines_str = ", ".join(search_engines)

        messages = [
            {"role": "system", "content": f"You are an AI research assistant. Generate optimized search queries for the following search engines: {engines_str}. For each search engine, provide 3 variations of the query that are optimized for that engine's search algorithm and will yield comprehensive results. Respond with a JSON object that maps each search engine name to a list of query strings, and do not include any text outside the JSON."},
            {"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"}
        ]

        response = await self.generate_completion(messages)

        try:
            # Try to parse as JSON
            queries = json.loads(response)
            return queries
        except json.JSONDecodeError:
            # If not valid JSON, return a basic query set
            return {engine: [query] for engine in search_engines}


# Create a singleton instance for global use
llm_interface = LLMInterface()


def get_llm_interface(model_name: Optional[str] = None) -> LLMInterface:
    """
    Get the global LLM interface instance or create a new one with a specific model.

    Args:
        model_name: Optional model name to use instead of the default

    Returns:
        LLMInterface instance
    """
    if model_name:
        return LLMInterface(model_name)
    return llm_interface


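# A minimal usage sketch (illustrative, not part of the module's public surface). It
# assumes config.config.get_config() finds a valid configuration with a default model
# and matching API key; the query text and engine names below are placeholders.
if __name__ == "__main__":
    async def _demo() -> None:
        llm = get_llm_interface()

        # Classify, enhance, and fan out a sample research query
        classification = await llm.classify_query("How do solid-state batteries compare to lithium-ion?")
        print(f"Classification: {classification}")

        enhanced = await llm.enhance_query("solid-state battery research")
        print(f"Enhanced query: {enhanced}")

        queries = await llm.generate_search_queries(enhanced, ["google", "arxiv"])
        print(f"Search queries: {queries}")

    asyncio.run(_demo())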