"""
|
|
LLM interface module using LiteLLM.
|
|
|
|
This module provides a unified interface to various LLM providers through LiteLLM,
|
|
enabling query enhancement, classification, and other LLM-powered functionality.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
from typing import Dict, Any, List, Optional, Tuple, Union
|
|
import asyncio
|
|
|
|
import litellm
|
|
from litellm import completion
|
|
|
|
from config.config import get_config
|
|
|
|
|
|
class LLMInterface:
    """Interface for interacting with LLMs through LiteLLM."""

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the LLM model to use. If None, uses the default model
                from configuration.
        """
        self.config = get_config()

        # Use the specified model or the default from configuration
        self.model_name = model_name or self.config.config_data.get('default_model', 'gpt-3.5-turbo')

        # Get model-specific configuration
        self.model_config = self.config.get_model_config(self.model_name)

        # Set up LiteLLM with the appropriate provider
        self._setup_provider()

    def _setup_provider(self) -> None:
        """Set up the LLM provider based on the model configuration."""
        provider = self.model_config.get('provider', 'openai')

        try:
            # Get the API key for the provider
            api_key = self.config.get_api_key(provider)

            # Set the environment variable LiteLLM expects for this provider
            if provider.lower() in ('google', 'gemini'):
                os.environ["GEMINI_API_KEY"] = api_key
            elif provider.lower() == 'vertex_ai':
                os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = api_key
            else:
                os.environ[f"{provider.upper()}_API_KEY"] = api_key

            print(f"LLM interface initialized with model: {self.model_name} (provider: {provider})")
        except ValueError as e:
            print(f"Error setting up LLM provider: {e}")

    def _get_completion_params(self) -> Dict[str, Any]:
        """
        Get parameters for LLM completion based on the model configuration.

        Returns:
            Dictionary of parameters for LiteLLM completion
        """
        params = {
            'temperature': self.model_config.get('temperature', 0.7),
            'max_tokens': self.model_config.get('max_tokens', 1000),
            'top_p': self.model_config.get('top_p', 1.0)
        }

        # Handle different provider configurations
        provider = self.model_config.get('provider', 'openai')

        if provider == 'azure':
            # Azure OpenAI requires special handling
            deployment_name = self.model_config.get('deployment_name')
            api_version = self.model_config.get('api_version')
            endpoint = self.model_config.get('endpoint')

            if deployment_name and endpoint:
                # Format: azure/deployment_name
                params['model'] = f"azure/{deployment_name}"

                # Set Azure-specific environment variables if not already set
                if 'AZURE_API_BASE' not in os.environ and endpoint:
                    os.environ['AZURE_API_BASE'] = endpoint

                if 'AZURE_API_VERSION' not in os.environ and api_version:
                    os.environ['AZURE_API_VERSION'] = api_version
            else:
                # Fall back to the default model if the Azure config is incomplete
                params['model'] = self.model_name
        elif provider in ['ollama', 'groq', 'openrouter'] or self.model_config.get('endpoint'):
            # Providers with custom endpoints
            params['model'] = self.model_config.get('model_name', self.model_name)
            params['api_base'] = self.model_config.get('endpoint')

            # Special handling for OpenRouter
            if provider == 'openrouter':
                # Set HTTP headers for OpenRouter if needed
                params['headers'] = {
                    'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                    'X-Title': 'Intelligent Research System'   # Replace with your actual app name
                }
        elif provider in ('google', 'gemini'):
            # Special handling for Google Gemini models
            # Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
            params['model'] = f"gemini/{self.model_config.get('model_name', self.model_name)}"

            # Add additional parameters for Gemini
            params['custom_llm_provider'] = 'gemini'
        elif provider == 'vertex_ai':
            # Special handling for Vertex AI Gemini models
            params['model'] = f"vertex_ai/{self.model_config.get('model_name', self.model_name)}"

            # Add Vertex AI-specific parameters
            params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
            params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')

            # Set custom provider
            params['custom_llm_provider'] = 'vertex_ai'
        else:
            # Standard provider (OpenAI, Anthropic, etc.)
            params['model'] = self.model_name

        return params

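    # Illustrative sketch of a model entry this method can consume. The key names
    # below are inferred from the .get() calls above; the real schema is whatever
    # config/config.py's get_model_config() returns, so treat this as an assumption.
    #
    #   llama3:
    #     provider: ollama
    #     model_name: llama3
    #     endpoint: http://localhost:11434
    #     temperature: 0.7
    #     max_tokens: 1000
    #
    # With that hypothetical entry, _get_completion_params() would produce roughly:
    #   {'model': 'llama3', 'api_base': 'http://localhost:11434',
    #    'temperature': 0.7, 'max_tokens': 1000, 'top_p': 1.0}
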
    async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
        """
        Generate a completion using the configured LLM.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            stream: Whether to stream the response

        Returns:
            If stream is False, returns the completion text as a string.
            If stream is True, returns the completion response object for streaming.
        """
        # Get the provider from the model config
        provider = self.model_config.get('provider', 'openai').lower()

        # Special handling for Gemini models - they use 'user' and 'model' roles
        if provider == 'gemini':
            formatted_messages = []
            for msg in messages:
                role = msg['role']
                # Map 'system' to 'user' for the first message
                if role == 'system' and not formatted_messages:
                    formatted_messages.append({
                        'role': 'user',
                        'content': msg['content']
                    })
                # Map 'assistant' to 'model'
                elif role == 'assistant':
                    formatted_messages.append({
                        'role': 'model',
                        'content': msg['content']
                    })
                # Pass through 'user' (and any later 'system') messages unchanged
                else:
                    formatted_messages.append(msg)
        else:
            formatted_messages = messages

        # Get completion parameters
        params = self._get_completion_params()

        try:
            # Generate the completion
            if stream:
                response = litellm.completion(
                    messages=formatted_messages,
                    stream=True,
                    **params
                )
                return response
            else:
                response = litellm.completion(
                    messages=formatted_messages,
                    **params
                )

                # Extract content from the response
                content = response.choices[0].message.content

                # Process thinking tags if enabled
                if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
                    content = self._process_thinking_tags(content)

                return content
        except Exception as e:
            error_msg = f"Error generating completion: {str(e)}"
            print(error_msg)

            # Return the error in a user-friendly format
            return f"I encountered an error while processing your request: {str(e)}"

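    # Minimal usage sketch (assumes a running event loop and valid API keys; not
    # executed here):
    #
    #   llm = LLMInterface()
    #   text = await llm.generate_completion(
    #       [{"role": "user", "content": "Summarize retrieval-augmented generation."}]
    #   )
    #
    # Note that litellm.completion() is the synchronous client, so this coroutine
    # blocks the running event loop while the request is in flight.
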
    async def classify_query_domain(self, query: str) -> Dict[str, Any]:
        """
        Classify a query's domain type (academic, code, current_events, general).

        Args:
            query: The query to classify

        Returns:
            Dictionary with query domain type and confidence scores
        """
        # Get the model assigned to this function
        model_name = self.config.get_module_model('query_processing', 'classify_query_domain')

        # Create a new interface with the assigned model if it differs from the current one
        if model_name != self.model_name:
            interface = LLMInterface(model_name)
            return await interface._classify_query_domain_impl(query)

        return await self._classify_query_domain_impl(query)

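    # Example of the dictionary shape this method asks the model to return. Values
    # are illustrative; real output depends on the model, and parse failures fall
    # back to the default in _classify_query_domain_impl below.
    #
    #   {"primary_type": "code", "confidence": 0.9,
    #    "secondary_types": [{"type": "academic", "confidence": 0.4}],
    #    "reasoning": "The query asks about a programming framework."}
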
    async def _classify_query_domain_impl(self, query: str) -> Dict[str, Any]:
        """Implementation of query domain classification."""
        messages = [
            {"role": "system", "content": """You are an expert query classifier.
            Analyze the given query and classify it into the following domain types:
            - academic: Related to scholarly research, scientific studies, academic papers, formal theories, university-level research topics, or scholarly fields of study
            - code: Related to programming, software development, technical implementation, coding languages, frameworks, or technology implementation questions
            - current_events: Related to recent news, ongoing developments, time-sensitive information, current politics, breaking stories, or real-time events
            - general: General information seeking that doesn't fit the above categories

            You may assign multiple types if the query spans several domains.

            Respond with a JSON object containing:
            {
                "primary_type": "the most appropriate type",
                "confidence": 0.X,
                "secondary_types": [{"type": "another_applicable_type", "confidence": 0.X}, ...],
                "reasoning": "brief explanation of your classification"
            }
            """},
            {"role": "user", "content": query}
        ]

        # Generate the classification
        response = await self.generate_completion(messages)

        # Parse the JSON response
        try:
            classification = json.loads(response)
            return classification
        except json.JSONDecodeError:
            # Fall back to a default classification if parsing fails
            print(f"Error parsing domain classification response: {response}")
            return {
                "primary_type": "general",
                "confidence": 0.5,
                "secondary_types": [],
                "reasoning": "Failed to parse classification response"
            }

    async def classify_query(self, query: str) -> Dict[str, str]:
        """
        Classify a query as factual, exploratory, or comparative.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query type and confidence
        """
        # Call the async implementation directly
        return await self._classify_query_impl(query)

    async def _classify_query_impl(self, query: str) -> Dict[str, str]:
        """
        Classify a query as factual, exploratory, or comparative.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query type and confidence
        """
        messages = [
            {"role": "system", "content": """You are an expert query classifier.
            Analyze the given query and classify it into one of the following types:
            - factual: Seeking specific facts or information
            - exploratory: Seeking to understand a topic broadly
            - comparative: Seeking to compare multiple items or concepts

            Respond with a JSON object containing:
            - type: The query type (factual, exploratory, or comparative)
            - confidence: Your confidence in this classification (high, medium, low)

            Example response:
            {"type": "exploratory", "confidence": "high"}
            """},
            {"role": "user", "content": query}
        ]

        # Generate the classification
        response = await self.generate_completion(messages)

        # Parse the JSON response
        try:
            classification = json.loads(response)
            return classification
        except json.JSONDecodeError:
            # Fall back to a default classification if parsing fails
            print(f"Error parsing classification response: {response}")
            return {"type": "exploratory", "confidence": "low"}

    async def enhance_query(self, query: str) -> str:
        """
        Enhance a user query using the LLM.

        Args:
            query: The raw user query

        Returns:
            Enhanced query with additional context and structure
        """
        # Get the model assigned to this specific function
        model_name = self.config.get_module_model('query_processing', 'enhance_query')

        # Create a new interface with the assigned model if it differs from the current one
        if model_name != self.model_name:
            interface = LLMInterface(model_name)
            return await interface._enhance_query_impl(query)

        return await self._enhance_query_impl(query)

    async def _enhance_query_impl(self, query: str) -> str:
        """Implementation of query enhancement."""
        messages = [
            {"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."},
            {"role": "user", "content": f"Enhance this research query: {query}"}
        ]

        return await self.generate_completion(messages)

    async def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """
        Generate optimized search queries for different search engines.

        Args:
            query: The original user query
            search_engines: List of search engines to generate queries for

        Returns:
            Dictionary mapping search engines to lists of optimized queries
        """
        # Get the model assigned to this specific function
        model_name = self.config.get_module_model('query_processing', 'generate_search_queries')

        # Create a new interface with the assigned model if it differs from the current one
        if model_name != self.model_name:
            interface = LLMInterface(model_name)
            return await interface._generate_search_queries_impl(query, search_engines)

        return await self._generate_search_queries_impl(query, search_engines)

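    # Example of the mapping this method aims to return, per the system prompt in
    # _generate_search_queries_impl below (queries shown are placeholders, not real output):
    #
    #   {"news": ["query variant 1", "query variant 2", "query variant 3"],
    #    "arxiv": ["query variant 1", "query variant 2", "query variant 3"]}
    #
    # If the model's response is not valid JSON, the implementation degrades to
    # {engine: [original_query]} for each requested engine.
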
    async def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """Implementation of search query generation."""
        engines_str = ", ".join(search_engines)

        # Special instructions for news searches
        news_instructions = ""
        if "news" in search_engines:
            news_instructions = """
            For the 'news' search engine:
            - Focus on recent events and timely information
            - Include specific date ranges when relevant (e.g., "last week", "since June 1")
            - Use names of people, organizations, or specific events
            - For current events queries, prioritize factual keywords over conceptual terms
            - Include terms like "latest", "recent", "update", "announcement" where appropriate
            - Exclude general background terms that would dilute current event focus
            - Generate 3 queries optimized for news search
            """

        # Special instructions for academic searches
        academic_instructions = ""
        if any(engine in search_engines for engine in ["openalex", "core", "arxiv"]):
            academic_instructions = """
            For academic search engines ('openalex', 'core', 'arxiv'):
            - Focus on specific academic terminology and precise research concepts
            - Include field-specific keywords and methodological terms
            - For 'openalex' search:
              - Include author names, journal names, or specific methodology terms when relevant
              - Be precise with scientific terminology
              - Consider including "review" or "meta-analysis" for summary-type queries
            - For 'core' search:
              - Focus on open access content
              - Include institutional keywords when relevant
              - Balance specificity with breadth
            - For 'arxiv' search:
              - Use more technical/mathematical terminology
              - Include relevant field categories (e.g., "cs.AI", "physics", "math")
              - Be precise with notation and specialized terms
            - Generate 3 queries optimized for each academic search engine
            """

        # Special instructions for code/programming searches
        code_instructions = ""
        if any(engine in search_engines for engine in ["github", "stackexchange"]):
            code_instructions = """
            For code/programming search engines ('github', 'stackexchange'):
            - Focus on specific technical terminology, programming languages, and frameworks
            - Include specific error messages, function names, or library references when relevant
            - For 'github' search:
              - Include programming language keywords (e.g., "python", "javascript", "java")
              - Specify file extensions when relevant (e.g., ".py", ".js", ".java")
              - Include framework or library names (e.g., "react", "tensorflow", "django")
              - Use code-specific syntax and terminology
              - Focus on implementation details, patterns, or techniques
            - For 'stackexchange' search:
              - Phrase as a specific programming question or problem
              - Include relevant error messages as exact quotes when applicable
              - Include specific version information when relevant
              - Use precise technical terms that would appear in developer discussions
              - Focus on problem-solving aspects or best practices
            - Generate 3 queries optimized for each code search engine
            """

        messages = [
            {"role": "system", "content": f"""You are an AI research assistant. Generate optimized search queries for the following search engines: {engines_str}.

            For each search engine, provide 3 variations of the query that are optimized for that engine's search algorithm and will yield comprehensive results.

            {news_instructions}
            {academic_instructions}
            {code_instructions}

            Return your response as a JSON object where each key is a search engine name and the value is an array of 3 optimized queries.
            """},
            {"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"}
        ]

        response = await self.generate_completion(messages)

        try:
            # Try to parse the response as JSON
            queries = json.loads(response)
            return queries
        except json.JSONDecodeError:
            # If not valid JSON, return a basic query set
            return {engine: [query] for engine in search_engines}


# Create a singleton instance for global use
llm_interface = LLMInterface()


def get_llm_interface(model_name: Optional[str] = None) -> LLMInterface:
    """
    Get the global LLM interface instance or create a new one with a specific model.

    Args:
        model_name: Optional model name to use instead of the default

    Returns:
        LLMInterface instance
    """
    if model_name:
        return LLMInterface(model_name)
    return llm_interface
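

if __name__ == "__main__":
    # Minimal smoke test, assuming API keys and module model assignments are set up
    # in config/config.py. The query string is only an illustrative placeholder.
    async def _demo() -> None:
        llm = get_llm_interface()

        enhanced = await llm.enhance_query("impact of transformer models on search")
        print(f"Enhanced query: {enhanced}")

        classification = await llm.classify_query("impact of transformer models on search")
        print(f"Classification: {classification}")

    asyncio.run(_demo())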