# Source file: ira/query/llm_interface.py
# (file-viewer metadata at capture time: 468 lines, 20 KiB, Python)

"""
LLM interface module using LiteLLM.
This module provides a unified interface to various LLM providers through LiteLLM,
enabling query enhancement, classification, and other LLM-powered functionality.
"""
import os
import json
from typing import Dict, Any, List, Optional, Tuple, Union
import asyncio
import litellm
from litellm import completion
from config.config import get_config
class LLMInterface:
    """Interface for interacting with LLMs through LiteLLM.

    Each instance is bound to one model name and that model's configuration
    entry. Provider credentials are exported to environment variables when the
    instance is created.
    """

    def __init__(self, model_name: Optional[str] = None):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the LLM model to use. If None (or empty), uses
                the 'default_model' entry from configuration, falling back to
                'gpt-3.5-turbo'.
        """
        self.config = get_config()
        # Use specified model or default from config
        self.model_name = model_name or self.config.config_data.get('default_model', 'gpt-3.5-turbo')
        # Get model-specific configuration
        self.model_config = self.config.get_model_config(self.model_name)
        # Export the provider credentials for LiteLLM
        self._setup_provider()

    def _setup_provider(self) -> None:
        """
        Export the provider's API key to the environment variable LiteLLM reads.

        A ValueError raised while looking up or exporting the key is reported
        on stdout and otherwise ignored, so construction never fails on it.
        """
        provider = self.model_config.get('provider', 'openai')
        try:
            api_key = self.config.get_api_key(provider)
            # Skip the export when no key is configured: assigning None to
            # os.environ raises TypeError, which would escape this handler.
            if api_key:
                provider_key = provider.lower()
                if provider_key in ('google', 'gemini'):
                    env_var = "GEMINI_API_KEY"
                elif provider_key == 'vertex_ai':
                    env_var = "GOOGLE_APPLICATION_CREDENTIALS"
                else:
                    env_var = f"{provider.upper()}_API_KEY"
                os.environ[env_var] = api_key
            print(f"LLM interface initialized with model: {self.model_name} (provider: {provider})")
        except ValueError as e:
            print(f"Error setting up LLM provider: {e}")

    def _get_completion_params(self) -> Dict[str, Any]:
        """
        Get parameters for LLM completion based on model configuration.

        The provider name is compared case-insensitively, matching
        _setup_provider and generate_completion.

        Returns:
            Dictionary of keyword arguments for litellm.completion (everything
            except 'messages' and 'stream').
        """
        params: Dict[str, Any] = {
            'temperature': self.model_config.get('temperature', 0.7),
            'max_tokens': self.model_config.get('max_tokens', 1000),
            'top_p': self.model_config.get('top_p', 1.0),
        }
        provider = self.model_config.get('provider', 'openai').lower()
        configured_model = self.model_config.get('model_name', self.model_name)

        if provider == 'azure':
            # Azure OpenAI is addressed by deployment name, not model name
            deployment_name = self.model_config.get('deployment_name')
            api_version = self.model_config.get('api_version')
            endpoint = self.model_config.get('endpoint')
            if deployment_name and endpoint:
                params['model'] = f"azure/{deployment_name}"
                # Never override Azure settings already present in the environment
                if 'AZURE_API_BASE' not in os.environ:
                    os.environ['AZURE_API_BASE'] = endpoint
                if 'AZURE_API_VERSION' not in os.environ and api_version:
                    os.environ['AZURE_API_VERSION'] = api_version
            else:
                # Fall back to the plain model name if Azure config is incomplete
                params['model'] = self.model_name
        elif provider in ('ollama', 'groq', 'openrouter') or self.model_config.get('endpoint'):
            # Providers with custom endpoints. This branch is checked before
            # the Gemini / Vertex AI ones, so any configured endpoint wins.
            params['model'] = configured_model
            params['api_base'] = self.model_config.get('endpoint')
            if provider == 'openrouter':
                # Attribution headers sent along with OpenRouter requests
                params['headers'] = {
                    'HTTP-Referer': 'https://sim-search.app',  # Replace with your actual app URL
                    'X-Title': 'Intelligent Research System'  # Replace with your actual app name
                }
        elif provider in ('google', 'gemini'):
            # Format: gemini/model_name (e.g., gemini/gemini-2.0-flash)
            params['model'] = f"gemini/{configured_model}"
            params['custom_llm_provider'] = 'gemini'
        elif provider == 'vertex_ai':
            params['model'] = f"vertex_ai/{configured_model}"
            params['vertex_project'] = self.model_config.get('vertex_project', 'sim-search')
            params['vertex_location'] = self.model_config.get('vertex_location', 'us-central1')
            params['custom_llm_provider'] = 'vertex_ai'
        else:
            # Standard provider (OpenAI, Anthropic, etc.)
            params['model'] = self.model_name
        return params

    @staticmethod
    def _format_gemini_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Rewrite chat roles into the 'user' / 'model' vocabulary.

        A leading 'system' message becomes a 'user' message and every
        'assistant' message becomes a 'model' message. All other messages,
        including any later 'system' message, are passed through unchanged.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys

        Returns:
            New list of messages; the input list is not modified.
        """
        formatted: List[Dict[str, str]] = []
        for msg in messages:
            role = msg['role']
            if role == 'system' and not formatted:
                formatted.append({'role': 'user', 'content': msg['content']})
            elif role == 'assistant':
                formatted.append({'role': 'model', 'content': msg['content']})
            else:
                formatted.append(msg)
        return formatted

    @staticmethod
    def _parse_json_response(response: Any) -> Optional[Dict[str, Any]]:
        """
        Parse an LLM reply that is expected to be a single JSON object.

        Accepts bare JSON as well as JSON wrapped in a Markdown code fence
        (``` or ```json), which models frequently emit.

        Args:
            response: Raw completion text (anything that is not a str is rejected)

        Returns:
            The decoded dictionary, or None if the reply is not a string, is
            not valid JSON, or decodes to something other than an object.
        """
        if not isinstance(response, str):
            return None
        text = response.strip()
        if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
            text = text[3:-3].strip()
            # Drop the optional language tag of the opening fence
            if text[:4].lower() == "json":
                text = text[4:]
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    def _interface_for(self, function_name: str) -> "LLMInterface":
        """
        Pick the interface that should serve a 'query_processing' function.

        Args:
            function_name: Function key passed to config.get_module_model

        Returns:
            This instance when the configured model equals self.model_name,
            otherwise a newly constructed LLMInterface for the configured model.
        """
        assigned_model = self.config.get_module_model('query_processing', function_name)
        if assigned_model != self.model_name:
            return LLMInterface(assigned_model)
        return self

    async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
        """
        Generate a completion using the configured LLM.

        The blocking litellm.completion call runs in the event loop's default
        executor so other coroutines keep running while the request is in flight.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            stream: Whether to stream the response

        Returns:
            If stream is False, returns the completion text as a string.
            If stream is True, returns the object litellm.completion returned.
            If the call raises, the error is printed and a user-facing error
            sentence is returned instead of raising.
        """
        provider = self.model_config.get('provider', 'openai').lower()
        # NOTE(review): the role rewrite applies to provider 'gemini' only, not
        # 'google', and LiteLLM may already translate OpenAI-style roles itself
        # - confirm whether this mapping is still needed.
        if provider == 'gemini':
            formatted_messages = self._format_gemini_messages(messages)
        else:
            formatted_messages = messages

        params = self._get_completion_params()
        if stream:
            params['stream'] = True

        try:
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: litellm.completion(messages=formatted_messages, **params),
            )
            if stream:
                return response

            content = response.choices[0].message.content
            # Optional post-processing hook: only applied when the flag is set
            # and a processor is actually available on this object.
            if getattr(self, 'process_thinking_tags', False):
                tag_processor = getattr(self, '_process_thinking_tags', None)
                if callable(tag_processor):
                    content = tag_processor(content)
            return content
        except Exception as e:
            error_msg = f"Error generating completion: {str(e)}"
            print(error_msg)
            # Return error message in a user-friendly format
            return f"I encountered an error while processing your request: {str(e)}"

    async def classify_query_domain(self, query: str) -> Dict[str, Any]:
        """
        Classify a query's domain type (academic, code, current_events, general).

        Uses the model configured for 'classify_query_domain', which may differ
        from this instance's model.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query domain type and confidence scores
        """
        interface = self._interface_for('classify_query_domain')
        return await interface._classify_query_domain_impl(query)

    async def _classify_query_domain_impl(self, query: str) -> Dict[str, Any]:
        """
        Implementation of query domain classification.

        Args:
            query: The query to classify

        Returns:
            The model's classification object, or a default 'general'
            classification with confidence 0.5 when the reply cannot be parsed.
        """
        messages = [
            {"role": "system", "content": """You are an expert query classifier.
Analyze the given query and classify it into the following domain types:
- academic: Related to scholarly research, scientific studies, academic papers, formal theories, university-level research topics, or scholarly fields of study
- code: Related to programming, software development, technical implementation, coding languages, frameworks, or technology implementation questions
- current_events: Related to recent news, ongoing developments, time-sensitive information, current politics, breaking stories, or real-time events
- general: General information seeking that doesn't fit the above categories
You may assign multiple types if the query spans several domains.
Respond with a JSON object containing:
{
"primary_type": "the most appropriate type",
"confidence": 0.X,
"secondary_types": [{"type": "another_applicable_type", "confidence": 0.X}, ...],
"reasoning": "brief explanation of your classification"
}
"""},
            {"role": "user", "content": query}
        ]
        response = await self.generate_completion(messages)
        classification = self._parse_json_response(response)
        if classification is not None:
            return classification
        # Fallback to default classification if parsing fails
        print(f"Error parsing domain classification response: {response}")
        return {
            "primary_type": "general",
            "confidence": 0.5,
            "secondary_types": [],
            "reasoning": "Failed to parse classification response"
        }

    async def classify_query(self, query: str) -> Dict[str, str]:
        """
        Classify a query as factual, exploratory, or comparative.

        Args:
            query: The query to classify

        Returns:
            Dictionary with query type and confidence
        """
        return await self._classify_query_impl(query)

    async def _classify_query_impl(self, query: str) -> Dict[str, str]:
        """
        Implementation of factual / exploratory / comparative classification.

        Args:
            query: The query to classify

        Returns:
            The model's classification object, or
            {"type": "exploratory", "confidence": "low"} when the reply cannot
            be parsed.
        """
        messages = [
            {"role": "system", "content": """You are an expert query classifier.
Analyze the given query and classify it into one of the following types:
- factual: Seeking specific facts or information
- exploratory: Seeking to understand a topic broadly
- comparative: Seeking to compare multiple items or concepts
Respond with a JSON object containing:
- type: The query type (factual, exploratory, or comparative)
- confidence: Your confidence in this classification (high, medium, low)
Example response:
{"type": "exploratory", "confidence": "high"}
"""},
            {"role": "user", "content": query}
        ]
        response = await self.generate_completion(messages)
        classification = self._parse_json_response(response)
        if classification is not None:
            return classification
        # Fallback to default classification if parsing fails
        print(f"Error parsing classification response: {response}")
        return {"type": "exploratory", "confidence": "low"}

    async def enhance_query(self, query: str) -> str:
        """
        Enhance a user query using the LLM.

        Uses the model configured for 'enhance_query', which may differ from
        this instance's model.

        Args:
            query: The raw user query

        Returns:
            Enhanced query with additional context and structure
        """
        interface = self._interface_for('enhance_query')
        return await interface._enhance_query_impl(query)

    async def _enhance_query_impl(self, query: str) -> str:
        """
        Implementation of query enhancement.

        Args:
            query: The raw user query

        Returns:
            Whatever generate_completion returns for the enhancement prompt.
        """
        messages = [
            {"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."},
            {"role": "user", "content": f"Enhance this research query: {query}"}
        ]
        return await self.generate_completion(messages)

    async def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """
        Generate optimized search queries for different search engines.

        Uses the model configured for 'generate_search_queries', which may
        differ from this instance's model.

        Args:
            query: The original user query
            search_engines: List of search engines to generate queries for

        Returns:
            Dictionary mapping search engines to lists of optimized queries
        """
        interface = self._interface_for('generate_search_queries')
        return await interface._generate_search_queries_impl(query, search_engines)

    async def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
        """
        Implementation of search query generation.

        Engine-specific guidance is added to the prompt only for the engine
        families that appear in search_engines.

        Args:
            query: The original user query
            search_engines: List of search engines to generate queries for

        Returns:
            The model's JSON object, or a mapping of every requested engine to
            a one-element list holding the original query when the reply
            cannot be parsed.
        """
        engines_str = ", ".join(search_engines)
        # Special instructions for news searches
        news_instructions = ""
        if "news" in search_engines:
            news_instructions = """
For the 'news' search engine:
- Focus on recent events and timely information
- Include specific date ranges when relevant (e.g., "last week", "since June 1")
- Use names of people, organizations, or specific events
- For current events queries, prioritize factual keywords over conceptual terms
- Include terms like "latest", "recent", "update", "announcement" where appropriate
- Exclude general background terms that would dilute current event focus
- Generate 3 queries optimized for news search
"""
        # Special instructions for academic searches
        academic_instructions = ""
        if any(engine in search_engines for engine in ["openalex", "core", "arxiv"]):
            academic_instructions = """
For academic search engines ('openalex', 'core', 'arxiv'):
- Focus on specific academic terminology and precise research concepts
- Include field-specific keywords and methodological terms
- For 'openalex' search:
- Include author names, journal names, or specific methodology terms when relevant
- Be precise with scientific terminology
- Consider including "review" or "meta-analysis" for summary-type queries
- For 'core' search:
- Focus on open access content
- Include institutional keywords when relevant
- Balance specificity with breadth
- For 'arxiv' search:
- Use more technical/mathematical terminology
- Include relevant field categories (e.g., "cs.AI", "physics", "math")
- Be precise with notation and specialized terms
- Generate 3 queries optimized for each academic search engine
"""
        # Special instructions for code/programming searches
        code_instructions = ""
        if any(engine in search_engines for engine in ["github", "stackexchange"]):
            code_instructions = """
For code/programming search engines ('github', 'stackexchange'):
- Focus on specific technical terminology, programming languages, and frameworks
- Include specific error messages, function names, or library references when relevant
- For 'github' search:
- Include programming language keywords (e.g., "python", "javascript", "java")
- Specify file extensions when relevant (e.g., ".py", ".js", ".java")
- Include framework or library names (e.g., "react", "tensorflow", "django")
- Use code-specific syntax and terminology
- Focus on implementation details, patterns, or techniques
- For 'stackexchange' search:
- Phrase as a specific programming question or problem
- Include relevant error messages as exact quotes when applicable
- Include specific version information when relevant
- Use precise technical terms that would appear in developer discussions
- Focus on problem-solving aspects or best practices
- Generate 3 queries optimized for each code search engine
"""
        messages = [
            {"role": "system", "content": f"""You are an AI research assistant. Generate optimized search queries for the following search engines: {engines_str}.
For each search engine, provide 3 variations of the query that are optimized for that engine's search algorithm and will yield comprehensive results.
{news_instructions}
{academic_instructions}
{code_instructions}
Return your response as a JSON object where each key is a search engine name and the value is an array of 3 optimized queries.
"""},
            {"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"}
        ]
        response = await self.generate_completion(messages)
        queries = self._parse_json_response(response)
        if queries is not None:
            return queries
        # If not a valid JSON object, return a basic query set
        return {engine: [query] for engine in search_engines}
# Shared module-level instance, built at import time with the default model
llm_interface = LLMInterface()


def get_llm_interface(model_name: Optional[str] = None) -> LLMInterface:
    """
    Return an LLM interface for the requested model.

    Args:
        model_name: Optional model name. When given (and non-empty), a new
            LLMInterface is constructed for it on every call.

    Returns:
        A new LLMInterface for model_name, or the shared module-level
        instance when no model name is supplied.
    """
    return LLMInterface(model_name) if model_name else llm_interface