Convert LLM interface methods to async to fix runtime errors

This commit is contained in:
Steve White 2025-02-28 16:44:10 -06:00
parent 0e0d4eb9b2
commit 6b749b9cb6
3 changed files with 67 additions and 56 deletions

View File

@ -8,6 +8,7 @@ enabling query enhancement, classification, and other LLM-powered functionality.
import os import os
import json import json
from typing import Dict, Any, List, Optional, Tuple, Union from typing import Dict, Any, List, Optional, Tuple, Union
import asyncio
import litellm import litellm
from litellm import completion from litellm import completion
@ -198,7 +199,59 @@ class LLMInterface:
# Return error message in a user-friendly format # Return error message in a user-friendly format
return f"I encountered an error while processing your request: {str(e)}" return f"I encountered an error while processing your request: {str(e)}"
def enhance_query(self, query: str) -> str: async def classify_query(self, query: str) -> Dict[str, str]:
"""
Classify a query as factual, exploratory, or comparative.
Args:
query: The query to classify
Returns:
Dictionary with query type and confidence
"""
# Call the async implementation directly
return await self._classify_query_impl(query)
async def _classify_query_impl(self, query: str) -> Dict[str, str]:
"""
Classify a query as factual, exploratory, or comparative.
Args:
query: The query to classify
Returns:
Dictionary with query type and confidence
"""
messages = [
{"role": "system", "content": """You are an expert query classifier.
Analyze the given query and classify it into one of the following types:
- factual: Seeking specific facts or information
- exploratory: Seeking to understand a topic broadly
- comparative: Seeking to compare multiple items or concepts
Respond with a JSON object containing:
- type: The query type (factual, exploratory, or comparative)
- confidence: Your confidence in this classification (high, medium, low)
Example response:
{"type": "exploratory", "confidence": "high"}
"""},
{"role": "user", "content": query}
]
# Generate classification
response = await self.generate_completion(messages)
# Parse JSON response
try:
classification = json.loads(response)
return classification
except json.JSONDecodeError:
# Fallback to default classification if parsing fails
print(f"Error parsing classification response: {response}")
return {"type": "exploratory", "confidence": "low"}
async def enhance_query(self, query: str) -> str:
""" """
Enhance a user query using the LLM. Enhance a user query using the LLM.
@ -214,62 +267,20 @@ class LLMInterface:
# Create a new interface with the assigned model if different from current # Create a new interface with the assigned model if different from current
if model_name != self.model_name: if model_name != self.model_name:
interface = LLMInterface(model_name) interface = LLMInterface(model_name)
return interface._enhance_query_impl(query) return await interface._enhance_query_impl(query)
return self._enhance_query_impl(query) return await self._enhance_query_impl(query)
def _enhance_query_impl(self, query: str) -> str: async def _enhance_query_impl(self, query: str) -> str:
"""Implementation of query enhancement.""" """Implementation of query enhancement."""
messages = [ messages = [
{"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."}, {"role": "system", "content": "You are an AI research assistant. Your task is to enhance the user's query by adding relevant context, clarifying ambiguities, and expanding key terms. Maintain the original intent of the query while making it more comprehensive and precise. Return ONLY the enhanced query text without any explanations, introductions, or additional text. The enhanced query should be ready to be sent directly to a search engine."},
{"role": "user", "content": f"Enhance this research query: {query}"} {"role": "user", "content": f"Enhance this research query: {query}"}
] ]
return self.generate_completion(messages) return await self.generate_completion(messages)
def classify_query(self, query: str) -> Dict[str, Any]: async def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
"""
Classify a query to determine its type, intent, and key entities.
Args:
query: The user query to classify
Returns:
Dictionary containing query classification information
"""
# Get the model assigned to this specific function
model_name = self.config.get_module_model('query_processing', 'classify_query')
# Create a new interface with the assigned model if different from current
if model_name != self.model_name:
interface = LLMInterface(model_name)
return interface._classify_query_impl(query)
return self._classify_query_impl(query)
def _classify_query_impl(self, query: str) -> Dict[str, Any]:
"""Implementation of query classification."""
messages = [
{"role": "system", "content": "You are an AI research assistant. Analyze the user's query and classify it according to type (factual, exploratory, comparative, etc.), intent, and key entities. Respond with a JSON object containing these classifications."},
{"role": "user", "content": f"Classify this research query: {query}"}
]
response = self.generate_completion(messages)
try:
# Try to parse as JSON
classification = json.loads(response)
return classification
except json.JSONDecodeError:
# If not valid JSON, return a basic classification
return {
"type": "unknown",
"intent": "research",
"entities": [query],
"error": "Failed to parse LLM response as JSON"
}
def generate_search_queries(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
""" """
Generate optimized search queries for different search engines. Generate optimized search queries for different search engines.
@ -286,11 +297,11 @@ class LLMInterface:
# Create a new interface with the assigned model if different from current # Create a new interface with the assigned model if different from current
if model_name != self.model_name: if model_name != self.model_name:
interface = LLMInterface(model_name) interface = LLMInterface(model_name)
return interface._generate_search_queries_impl(query, search_engines) return await interface._generate_search_queries_impl(query, search_engines)
return self._generate_search_queries_impl(query, search_engines) return await self._generate_search_queries_impl(query, search_engines)
def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]: async def _generate_search_queries_impl(self, query: str, search_engines: List[str]) -> Dict[str, List[str]]:
"""Implementation of search query generation.""" """Implementation of search query generation."""
engines_str = ", ".join(search_engines) engines_str = ", ".join(search_engines)
@ -299,7 +310,7 @@ class LLMInterface:
{"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"} {"role": "user", "content": f"Generate optimized search queries for this research topic: {query}"}
] ]
response = self.generate_completion(messages) response = await self.generate_completion(messages)
try: try:
# Try to parse as JSON # Try to parse as JSON

View File

@ -22,7 +22,7 @@ class QueryProcessor:
"""Initialize the query processor.""" """Initialize the query processor."""
self.llm_interface = get_llm_interface() self.llm_interface = get_llm_interface()
def process_query(self, query: str) -> Dict[str, Any]: async def process_query(self, query: str) -> Dict[str, Any]:
""" """
Process a user query. Process a user query.
@ -33,10 +33,10 @@ class QueryProcessor:
Dictionary containing the processed query information Dictionary containing the processed query information
""" """
# Enhance the query # Enhance the query
enhanced_query = self.llm_interface.enhance_query(query) enhanced_query = await self.llm_interface.enhance_query(query)
# Classify the query # Classify the query
classification = self.llm_interface.classify_query(query) classification = await self.llm_interface.classify_query(query)
# Extract entities from the classification # Extract entities from the classification
entities = classification.get('entities', []) entities = classification.get('entities', [])

View File

@ -62,7 +62,7 @@ async def query_to_report(
# Step 1: Process the query # Step 1: Process the query
query_processor = get_query_processor() query_processor = get_query_processor()
structured_query = query_processor.process_query(query) structured_query = await query_processor.process_query(query)
# Add timestamp # Add timestamp
structured_query['timestamp'] = datetime.now().isoformat() structured_query['timestamp'] = datetime.now().isoformat()