Convert LLM interface methods to async to fix runtime errors
commit 6b749b9cb6
parent 0e0d4eb9b2
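For context, the failure mode named in the commit title reproduces with a minimal sketch (class and method names mirror the diff below; the bodies are stand-ins): a synchronous wrapper around an async method returns a coroutine object instead of the result, and Python warns that the coroutine was never awaited.

import asyncio

class Sketch:
    async def generate_completion(self, messages):
        # Stand-in for the real LiteLLM-backed call
        return "enhanced query text"

    def enhance_query_broken(self, query):
        # Pre-fix pattern: returns a coroutine object, not a string,
        # which breaks every downstream consumer at runtime
        return self.generate_completion([{"role": "user", "content": query}])

    async def enhance_query(self, query):
        # Post-fix pattern: the wrapper is async and awaits the result
        return await self.generate_completion([{"role": "user", "content": query}])

print(Sketch().enhance_query_broken("q"))        # <coroutine object ...> plus RuntimeWarning
print(asyncio.run(Sketch().enhance_query("q")))  # enhanced query text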
@@ -8,6 +8,7 @@ enabling query enhancement, classification, and other LLM-powered functionality.
 import os
 import json
 from typing import Dict, Any, List, Optional, Tuple, Union
+import asyncio
 
 import litellm
 from litellm import completion
@@ -198,7 +199,59 @@ class LLMInterface:
             # Return error message in a user-friendly format
             return f"I encountered an error while processing your request: {str(e)}"
 
-    def enhance_query(self, query: str) -> str:
+    async def classify_query(self, query: str) -> Dict[str, str]:
+        """
+        Classify a query as factual, exploratory, or comparative.
+
+        Args:
+            query: The query to classify
+
+        Returns:
+            Dictionary with query type and confidence
+        """
+        # Call the async implementation directly
+        return await self._classify_query_impl(query)
+
+    async def _classify_query_impl(self, query: str) -> Dict[str, str]:
+        """
+        Classify a query as factual, exploratory, or comparative.
+
+        Args:
+            query: The query to classify
+
+        Returns:
+            Dictionary with query type and confidence
+        """
+        messages = [
+            {"role": "system", "content": """You are an expert query classifier.
+Analyze the given query and classify it into one of the following types:
+- factual: Seeking specific facts or information
+- exploratory: Seeking to understand a topic broadly
+- comparative: Seeking to compare multiple items or concepts
+
+Respond with a JSON object containing:
+- type: The query type (factual, exploratory, or comparative)
+- confidence: Your confidence in this classification (high, medium, low)
+
+Example response:
+{"type": "exploratory", "confidence": "high"}
+"""},
+            {"role": "user", "content": query}
+        ]
+
+        # Generate classification
+        response = await self.generate_completion(messages)
+
+        # Parse JSON response
+        try:
+            classification = json.loads(response)
+            return classification
+        except json.JSONDecodeError:
+            # Fallback to default classification if parsing fails
+            print(f"Error parsing classification response: {response}")
+            return {"type": "exploratory", "confidence": "low"}
+
+    async def enhance_query(self, query: str) -> str:
         """
         Enhance a user query using the LLM.
 
@@ -214,62 +267,20 @@ class LLMInterface:
         # Create a new interface with the assigned model if different from current
         if model_name != self.model_name:
             interface = LLMInterface(model_name)
-            return interface._enhance_query_impl(query)
+            return await interface._enhance_query_impl(query)
 
-        return self._enhance_query_impl(query)
+        return await self._enhance_query_impl(query)
 
-    def _enhance_query_impl(self, query: str) -> str:
+    async def _enhance_query_impl(self, query: str) -> str:
         """Implementation of query enhancement."""
         messages = [
             {"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."},
             {"role": "user", "content": f"Enhance this research query: {query}"}
         ]
 
-        return self.generate_completion(messages)
+        return await self.generate_completion(messages)
 
-    def classify_query(self, query: str) -> Dict[str, Any]:
-        """
-        Classify a query to determine its type, intent, and key entities.
-
-        Args:
-            query: The user query to classify
-
-        Returns:
-            Dictionary containing query classification information
-        """
-        # Get the model assigned to this specific function
-        model_name = self.config.get_module_model('query_processing', 'classify_query')
-
-        # Create a new interface with the assigned model if different from current
-        if model_name != self.model_name:
-            interface = LLMInterface(model_name)
-            return interface._classify_query_impl(query)
-
-        return self._classify_query_impl(query)
-
-    def _classify_query_impl(self, query: str) -> Dict[str, Any]:
-        """Implementation of query classification."""
-        messages = [
-            {"role": "system", "content": "You are an AI research assistant. Analyze the user's query and classify it according to type (factual, exploratory, comparative, etc.), intent, and key entities. Respond with a JSON object containing these classifications."},
-            {"role": "user", "content": f"Classify this research query: {query}"}
-        ]
-
-        response = self.generate_completion(messages)
-
-        try:
-            # Try to parse as JSON
-            classification = json.loads(response)
-            return classification
-        except json.JSONDecodeError:
-            # If not valid JSON, return a basic classification
-            return {
-                "type": "unknown",
-                "intent": "research",
-                "entities": [query],
-                "error": "Failed to parse LLM response as JSON"
-            }
-
-    def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
+    async def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
         """
         Generate optimized search queries for different search engines.
 
@@ -286,11 +297,11 @@ class LLMInterface:
         # Create a new interface with the assigned model if different from current
         if model_name != self.model_name:
             interface = LLMInterface(model_name)
-            return interface._generate_search_queries_impl(query, search_engines)
+            return await interface._generate_search_queries_impl(query, search_engines)
 
-        return self._generate_search_queries_impl(query, search_engines)
+        return await self._generate_search_queries_impl(query, search_engines)
 
-    def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
+    async def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
         """Implementation of search query generation."""
         engines_str = ", ".join(search_engines)
 
@@ -299,7 +310,7 @@ class LLMInterface:
             {"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"}
         ]
 
-        response = self.generate_completion(messages)
+        response = await self.generate_completion(messages)
 
         try:
             # Try to parse as JSON
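With the interface methods converted, every caller must itself run inside a coroutine. A minimal usage sketch, assuming get_llm_interface is importable from this module (the import path and sample query are assumptions):

import asyncio

from llm_interface import get_llm_interface  # import path is an assumption

async def main():
    interface = get_llm_interface()
    enhanced = await interface.enhance_query("solid-state battery adoption in EVs")
    classification = await interface.classify_query(enhanced)
    print(classification)  # e.g. {"type": "exploratory", "confidence": "high"}

asyncio.run(main())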
@@ -22,7 +22,7 @@ class QueryProcessor:
         """Initialize the query processor."""
         self.llm_interface = get_llm_interface()
 
-    def process_query(self, query: str) -> Dict[str, Any]:
+    async def process_query(self, query: str) -> Dict[str, Any]:
         """
         Process a user query.
 
@@ -33,10 +33,10 @@ class QueryProcessor:
             Dictionary containing the processed query information
         """
         # Enhance the query
-        enhanced_query = self.llm_interface.enhance_query(query)
+        enhanced_query = await self.llm_interface.enhance_query(query)
 
         # Classify the query
-        classification = self.llm_interface.classify_query(query)
+        classification = await self.llm_interface.classify_query(query)
 
         # Extract entities from the classification
         entities = classification.get('entities', [])
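Because process_query is now a coroutine, the conversion propagates to every caller up the stack. Any remaining synchronous call site would need an explicit bridge; a sketch under the assumption that such a site exists:

import asyncio

def process_query_blocking(processor, query):
    # Bridge for a legacy synchronous caller; safe only when no event
    # loop is already running in the current thread
    return asyncio.run(processor.process_query(query))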
@@ -62,7 +62,7 @@ async def query_to_report(
 
     # Step 1: Process the query
     query_processor = get_query_processor()
-    structured_query = query_processor.process_query(query)
+    structured_query = await query_processor.process_query(query)
 
     # Add timestamp
     structured_query['timestamp'] = datetime.now().isoformat()
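End to end, a top-level entry point now drives the fully async pipeline roughly as follows (query_to_report was already async before this commit; the import path and sample query are assumptions, and its remaining parameters are assumed to have defaults):

import asyncio

from report_generator import query_to_report  # import path is an assumption

async def main():
    report = await query_to_report("compare lithium-ion and solid-state batteries")
    print(report)

if __name__ == "__main__":
    asyncio.run(main())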
|
Loading…
Reference in New Issue