391 lines
17 KiB
Python
391 lines
17 KiB
Python
"""
|
|
Query processor module for the intelligent research system.
|
|
|
|
This module handles the processing of user queries, including enhancement,
|
|
classification, decomposition, and structuring for downstream modules.
|
|
"""
|
|
|
|
import logging
import re
from typing import Dict, Any, List, Optional

from .llm_interface import get_llm_interface
from .query_decomposer import get_query_decomposer
|
|
|
|
# Configure logging
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class QueryProcessor:
    """
    Processor for user research queries.

    This class handles the processing of user queries, including enhancement,
    classification, and structuring for downstream modules.
    """

    def __init__(self):
        """Initialize the query processor."""
        # LLM backend used for enhancement and type/domain classification.
        self.llm_interface = get_llm_interface()
        # Decomposer that may split a complex query into sub-questions.
        self.query_decomposer = get_query_decomposer()
|
|
|
|
async def process_query(self, query: str) -> Dict[str, Any]:
|
|
"""
|
|
Process a user query.
|
|
|
|
Args:
|
|
query: The raw user query
|
|
|
|
Returns:
|
|
Dictionary containing the processed query information
|
|
"""
|
|
logger.info(f"Processing query: {query}")
|
|
|
|
# Enhance the query
|
|
enhanced_query = await self.llm_interface.enhance_query(query)
|
|
logger.info(f"Enhanced query: {enhanced_query}")
|
|
|
|
# Classify the query type (factual, exploratory, comparative)
|
|
query_type_classification = await self.llm_interface.classify_query(query)
|
|
logger.info(f"Query type classification: {query_type_classification}")
|
|
|
|
# Classify the query domain (academic, code, current_events, general)
|
|
domain_classification = await self.llm_interface.classify_query_domain(query)
|
|
logger.info(f"Query domain classification: {domain_classification}")
|
|
|
|
# Log classification details for monitoring
|
|
if domain_classification.get('secondary_types'):
|
|
for sec_type in domain_classification.get('secondary_types'):
|
|
logger.info(f"Secondary domain: {sec_type['type']} confidence={sec_type['confidence']}")
|
|
logger.info(f"Classification reasoning: {domain_classification.get('reasoning', 'None provided')}")
|
|
|
|
try:
|
|
# Structure the query using the new classification approach
|
|
structured_query = self._structure_query_with_llm(query, enhanced_query, query_type_classification, domain_classification)
|
|
except Exception as e:
|
|
logger.error(f"LLM domain classification failed: {e}. Falling back to keyword-based classification.")
|
|
# Fallback to keyword-based approach
|
|
structured_query = self._structure_query(query, enhanced_query, query_type_classification)
|
|
|
|
# Decompose the query into sub-questions (if complex enough)
|
|
structured_query = await self.query_decomposer.decompose_query(query, structured_query)
|
|
|
|
# Log the number of sub-questions if any
|
|
if 'sub_questions' in structured_query and structured_query['sub_questions']:
|
|
logger.info(f"Decomposed into {len(structured_query['sub_questions'])} sub-questions")
|
|
else:
|
|
logger.info("Query was not decomposed into sub-questions")
|
|
|
|
return structured_query
|
|
|
|
def _structure_query_with_llm(self, original_query: str, enhanced_query: str,
|
|
type_classification: Dict[str, Any],
|
|
domain_classification: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Structure a query using LLM classification results.
|
|
|
|
Args:
|
|
original_query: The original user query
|
|
enhanced_query: The enhanced query
|
|
type_classification: Classification of query type (factual, exploratory, comparative)
|
|
domain_classification: Classification of query domain (academic, code, current_events)
|
|
|
|
Returns:
|
|
Dictionary containing the structured query
|
|
"""
|
|
# Get primary domain and confidence
|
|
primary_domain = domain_classification.get('primary_type', 'general')
|
|
primary_confidence = domain_classification.get('confidence', 0.5)
|
|
|
|
# Get secondary domains
|
|
secondary_domains = domain_classification.get('secondary_types', [])
|
|
|
|
# Determine domain flags
|
|
is_academic = primary_domain == 'academic' or any(d['type'] == 'academic' for d in secondary_domains)
|
|
is_code = primary_domain == 'code' or any(d['type'] == 'code' for d in secondary_domains)
|
|
is_current_events = primary_domain == 'current_events' or any(d['type'] == 'current_events' for d in secondary_domains)
|
|
|
|
# Higher threshold for secondary domains to avoid false positives
|
|
if primary_domain != 'academic' and any(d['type'] == 'academic' and d['confidence'] >= 0.3 for d in secondary_domains):
|
|
is_academic = True
|
|
|
|
if primary_domain != 'code' and any(d['type'] == 'code' and d['confidence'] >= 0.3 for d in secondary_domains):
|
|
is_code = True
|
|
|
|
if primary_domain != 'current_events' and any(d['type'] == 'current_events' and d['confidence'] >= 0.3 for d in secondary_domains):
|
|
is_current_events = True
|
|
|
|
return {
|
|
'original_query': original_query,
|
|
'enhanced_query': enhanced_query,
|
|
'type': type_classification.get('type', 'unknown'),
|
|
'intent': type_classification.get('intent', 'research'),
|
|
'entities': type_classification.get('entities', []),
|
|
'domain': primary_domain,
|
|
'domain_confidence': primary_confidence,
|
|
'secondary_domains': secondary_domains,
|
|
'classification_reasoning': domain_classification.get('reasoning', ''),
|
|
'timestamp': None, # Will be filled in by the caller
|
|
'is_current_events': is_current_events,
|
|
'is_academic': is_academic,
|
|
'is_code': is_code,
|
|
'metadata': {
|
|
'type_classification': type_classification,
|
|
'domain_classification': domain_classification
|
|
}
|
|
}
|
|
|
|
def _structure_query(self, original_query: str, enhanced_query: str,
|
|
classification: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""
|
|
Structure a query for downstream modules using keyword-based classification.
|
|
This is a fallback method when LLM classification fails.
|
|
|
|
Args:
|
|
original_query: The original user query
|
|
enhanced_query: The enhanced query
|
|
classification: The query classification
|
|
|
|
Returns:
|
|
Dictionary containing the structured query
|
|
"""
|
|
# Detect query types using keyword-based methods
|
|
is_current_events = self._is_current_events_query(original_query, classification)
|
|
is_academic = self._is_academic_query(original_query, classification)
|
|
is_code = self._is_code_query(original_query, classification)
|
|
|
|
return {
|
|
'original_query': original_query,
|
|
'enhanced_query': enhanced_query,
|
|
'type': classification.get('type', 'unknown'),
|
|
'intent': classification.get('intent', 'research'),
|
|
'entities': classification.get('entities', []),
|
|
'timestamp': None, # Will be filled in by the caller
|
|
'is_current_events': is_current_events,
|
|
'is_academic': is_academic,
|
|
'is_code': is_code,
|
|
'metadata': {
|
|
'classification': classification,
|
|
'classification_method': 'keyword' # Indicate this used the keyword-based method
|
|
}
|
|
}
|
|
|
|
def _is_current_events_query(self, query: str, classification: Dict[str, Any]) -> bool:
|
|
"""
|
|
Determine if a query is related to current events.
|
|
|
|
Args:
|
|
query: The original user query
|
|
classification: The query classification
|
|
|
|
Returns:
|
|
True if the query is about current events, False otherwise
|
|
"""
|
|
# Check for time-related keywords in the query
|
|
time_keywords = ['recent', 'latest', 'current', 'today', 'yesterday', 'week', 'month',
|
|
'this year', 'breaking', 'news', 'announced', 'election',
|
|
'now', 'trends', 'emerging']
|
|
|
|
query_lower = query.lower()
|
|
|
|
# Check for named entities typical of current events
|
|
current_event_entities = ['trump', 'biden', 'president', 'government', 'congress',
|
|
'senate', 'tariffs', 'election', 'policy', 'coronavirus',
|
|
'covid', 'market', 'stocks', 'stock market', 'war']
|
|
|
|
# Count matches for time keywords
|
|
time_keyword_count = sum(1 for keyword in time_keywords if keyword in query_lower)
|
|
|
|
# Count matches for current event entities
|
|
entity_count = sum(1 for entity in current_event_entities if entity in query_lower)
|
|
|
|
# If the query directly asks about what's happening or what happened
|
|
action_verbs = ['happen', 'occurred', 'announced', 'said', 'stated', 'declared', 'launched']
|
|
verb_matches = sum(1 for verb in action_verbs if verb in query_lower)
|
|
|
|
# Determine if this is likely a current events query
|
|
# Either multiple time keywords or current event entities, or a combination
|
|
is_current = (time_keyword_count >= 1 and entity_count >= 1) or time_keyword_count >= 2 or entity_count >= 2 or verb_matches >= 1
|
|
|
|
return is_current
|
|
|
|
def _is_academic_query(self, query: str, classification: Dict[str, Any]) -> bool:
|
|
"""
|
|
Determine if a query is related to academic or scholarly research.
|
|
|
|
Args:
|
|
query: The original user query
|
|
classification: The query classification
|
|
|
|
Returns:
|
|
True if the query is about academic research, False otherwise
|
|
"""
|
|
query_lower = query.lower()
|
|
|
|
# Check for academic terms
|
|
academic_terms = [
|
|
'paper', 'study', 'research', 'publication', 'journal', 'article', 'thesis',
|
|
'dissertation', 'scholarly', 'academic', 'literature', 'published', 'author',
|
|
'citation', 'cited', 'references', 'bibliography', 'doi', 'peer-reviewed',
|
|
'peer reviewed', 'university', 'professor', 'conference', 'proceedings'
|
|
]
|
|
|
|
# Check for research methodologies
|
|
methods = [
|
|
'methodology', 'experiment', 'hypothesis', 'theoretical', 'empirical',
|
|
'qualitative', 'quantitative', 'data', 'analysis', 'statistical', 'results',
|
|
'findings', 'conclusion', 'meta-analysis', 'systematic review', 'clinical trial'
|
|
]
|
|
|
|
# Check for academic fields
|
|
fields = [
|
|
'science', 'physics', 'chemistry', 'biology', 'psychology', 'sociology',
|
|
'economics', 'history', 'philosophy', 'engineering', 'computer science',
|
|
'medicine', 'mathematics', 'geology', 'astronomy', 'linguistics'
|
|
]
|
|
|
|
# Count matches
|
|
academic_term_count = sum(1 for term in academic_terms if term in query_lower)
|
|
method_count = sum(1 for method in methods if method in query_lower)
|
|
field_count = sum(1 for field in fields if field in query_lower)
|
|
|
|
# Check for common academic question patterns
|
|
academic_patterns = [
|
|
'what does research say about',
|
|
'what studies show',
|
|
'according to research',
|
|
'scholarly view',
|
|
'academic consensus',
|
|
'published papers on',
|
|
'recent studies on',
|
|
'literature review',
|
|
'research findings',
|
|
'scientific evidence'
|
|
]
|
|
|
|
pattern_matches = sum(1 for pattern in academic_patterns if pattern in query_lower)
|
|
|
|
# Determine if this is likely an academic query
|
|
# Either multiple academic terms, or a combination of terms, methods, and fields
|
|
is_academic = (
|
|
academic_term_count >= 2 or
|
|
pattern_matches >= 1 or
|
|
(academic_term_count >= 1 and (method_count >= 1 or field_count >= 1)) or
|
|
(method_count >= 1 and field_count >= 1)
|
|
)
|
|
|
|
return is_academic
|
|
|
|
def _is_code_query(self, query: str, classification: Dict[str, Any]) -> bool:
|
|
"""
|
|
Determine if a query is related to programming or code.
|
|
|
|
Args:
|
|
query: The original user query
|
|
classification: The query classification
|
|
|
|
Returns:
|
|
True if the query is about programming or code, False otherwise
|
|
"""
|
|
query_lower = query.lower()
|
|
|
|
# Check for programming languages and technologies
|
|
programming_langs = [
|
|
'python', 'javascript', 'java', 'c++', 'c#', 'ruby', 'go', 'rust',
|
|
'php', 'swift', 'kotlin', 'typescript', 'perl', 'scala', 'r',
|
|
'html', 'css', 'sql', 'bash', 'powershell', 'dart', 'julia'
|
|
]
|
|
|
|
# Check for programming frameworks and libraries
|
|
frameworks = [
|
|
'react', 'angular', 'vue', 'django', 'flask', 'spring', 'laravel',
|
|
'express', 'tensorflow', 'pytorch', 'pandas', 'numpy', 'scikit-learn',
|
|
'bootstrap', 'jquery', 'node', 'rails', 'asp.net', 'unity', 'flutter',
|
|
'pytorch', 'keras', '.net', 'core', 'maven', 'gradle', 'npm', 'pip'
|
|
]
|
|
|
|
# Check for programming concepts and terms
|
|
programming_terms = [
|
|
'algorithm', 'function', 'class', 'method', 'variable', 'object', 'array',
|
|
'string', 'integer', 'boolean', 'list', 'dictionary', 'hash', 'loop',
|
|
'recursion', 'inheritance', 'interface', 'api', 'rest', 'json', 'xml',
|
|
'database', 'query', 'schema', 'framework', 'library', 'package', 'module',
|
|
'dependency', 'bug', 'error', 'exception', 'debugging', 'compiler', 'runtime',
|
|
'syntax', 'parameter', 'argument', 'return', 'value', 'reference', 'pointer',
|
|
'memory', 'stack', 'heap', 'thread', 'async', 'await', 'promise', 'callback',
|
|
'event', 'listener', 'handler', 'middleware', 'frontend', 'backend', 'fullstack',
|
|
'devops', 'ci/cd', 'docker', 'kubernetes', 'git', 'github', 'bitbucket', 'gitlab'
|
|
]
|
|
|
|
# Check for programming question patterns
|
|
code_patterns = [
|
|
'how to code', 'how do i program', 'how to program', 'how to implement',
|
|
'code example', 'example code', 'code snippet', 'write a function',
|
|
'write a program', 'debugging', 'error message', 'getting error',
|
|
'code review', 'refactor', 'optimize', 'performance issue',
|
|
'best practice', 'design pattern', 'architecture', 'software design',
|
|
'algorithm for', 'data structure', 'time complexity', 'space complexity',
|
|
'big o', 'optimize code', 'refactor code', 'clean code', 'technical debt',
|
|
'unit test', 'integration test', 'test coverage', 'mock', 'stub'
|
|
]
|
|
|
|
# Count matches
|
|
lang_count = sum(1 for lang in programming_langs if lang in query_lower)
|
|
framework_count = sum(1 for framework in frameworks if framework in query_lower)
|
|
term_count = sum(1 for term in programming_terms if term in query_lower)
|
|
pattern_count = sum(1 for pattern in code_patterns if pattern in query_lower)
|
|
|
|
# Check if the query contains code or a code block (denoted by backticks or indentation)
|
|
contains_code_block = '```' in query or any(line.strip().startswith(' ') for line in query.split('\n'))
|
|
|
|
# Determine if this is likely a code-related query
|
|
is_code = (
|
|
lang_count >= 1 or
|
|
framework_count >= 1 or
|
|
term_count >= 2 or
|
|
pattern_count >= 1 or
|
|
contains_code_block or
|
|
(lang_count + framework_count + term_count >= 2)
|
|
)
|
|
|
|
return is_code
|
|
|
|
async def generate_search_queries(self, structured_query: Dict[str, Any],
|
|
search_engines: List[str]) -> Dict[str, Any]:
|
|
"""
|
|
Generate optimized search queries for different search engines.
|
|
|
|
Args:
|
|
structured_query: The structured query
|
|
search_engines: List of search engines to generate queries for
|
|
|
|
Returns:
|
|
Updated structured query with search queries
|
|
"""
|
|
# Use the enhanced query for generating search queries
|
|
enhanced_query = structured_query['enhanced_query']
|
|
|
|
# Generate search queries for each engine
|
|
search_queries = await self.llm_interface.generate_search_queries(
|
|
enhanced_query, search_engines
|
|
)
|
|
|
|
# Add search queries to the structured query
|
|
structured_query['search_queries'] = search_queries
|
|
|
|
return structured_query
|
|
|
|
|
|
# Create a singleton instance for global use.
# NOTE(review): this instantiates QueryProcessor eagerly at import time,
# which invokes get_llm_interface()/get_query_decomposer() as an import
# side effect — confirm that is intended before moving this module.
query_processor = QueryProcessor()


def get_query_processor() -> QueryProcessor:
    """
    Get the global query processor instance.

    Returns:
        QueryProcessor instance
    """
    return query_processor
|