Fix reference formatting for Gemini models and improve message handling
This commit is contained in:
parent
d4beb73a7a
commit
8cd2f900c1
|
@ -127,7 +127,7 @@ class LLMInterface:
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
|
async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
|
||||||
"""
|
"""
|
||||||
Generate a completion using the configured LLM.
|
Generate a completion using the configured LLM.
|
||||||
|
|
||||||
|
@ -139,41 +139,64 @@ class LLMInterface:
|
||||||
If stream is False, returns the completion text as a string
|
If stream is False, returns the completion text as a string
|
||||||
If stream is True, returns the completion response object for streaming
|
If stream is True, returns the completion response object for streaming
|
||||||
"""
|
"""
|
||||||
|
# Get provider from model config
|
||||||
|
provider = self.model_config.get('provider', 'openai').lower()
|
||||||
|
|
||||||
|
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||||
|
if provider == 'gemini':
|
||||||
|
formatted_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
role = msg['role']
|
||||||
|
# Map 'system' to 'user' for the first message
|
||||||
|
if role == 'system' and not formatted_messages:
|
||||||
|
formatted_messages.append({
|
||||||
|
'role': 'user',
|
||||||
|
'content': msg['content']
|
||||||
|
})
|
||||||
|
# Map 'assistant' to 'model'
|
||||||
|
elif role == 'assistant':
|
||||||
|
formatted_messages.append({
|
||||||
|
'role': 'model',
|
||||||
|
'content': msg['content']
|
||||||
|
})
|
||||||
|
# Keep 'user' as is
|
||||||
|
else:
|
||||||
|
formatted_messages.append(msg)
|
||||||
|
else:
|
||||||
|
formatted_messages = messages
|
||||||
|
|
||||||
|
# Get completion parameters
|
||||||
|
params = self._get_completion_params()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
params = self._get_completion_params()
|
# Generate completion
|
||||||
|
|
||||||
# Special handling for Gemini models
|
|
||||||
if "gemini" in params.get('model', '').lower():
|
|
||||||
# Format messages for Gemini models
|
|
||||||
formatted_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
role = msg['role']
|
|
||||||
# Gemini uses 'user' and 'model' roles (not 'assistant')
|
|
||||||
if role == 'assistant':
|
|
||||||
role = 'model'
|
|
||||||
# For system messages, convert to user messages with a prefix
|
|
||||||
if role == 'system':
|
|
||||||
formatted_messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": f"System instruction: {msg['content']}"
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
formatted_messages.append({"role": role, "content": msg['content']})
|
|
||||||
params['messages'] = formatted_messages
|
|
||||||
else:
|
|
||||||
params['messages'] = messages
|
|
||||||
|
|
||||||
params['stream'] = stream
|
|
||||||
|
|
||||||
response = completion(**params)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
response = litellm.completion(
|
||||||
|
messages=formatted_messages,
|
||||||
|
stream=True,
|
||||||
|
**params
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
return response.choices[0].message.content
|
response = litellm.completion(
|
||||||
|
messages=formatted_messages,
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract content from response
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# Process thinking tags if enabled
|
||||||
|
if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
|
||||||
|
content = self._process_thinking_tags(content)
|
||||||
|
|
||||||
|
return content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating completion: {e}")
|
error_msg = f"Error generating completion: {str(e)}"
|
||||||
return f"Error: {str(e)}"
|
print(error_msg)
|
||||||
|
|
||||||
|
# Return error message in a user-friendly format
|
||||||
|
return f"I encountered an error while processing your request: {str(e)}"
|
||||||
|
|
||||||
def enhance_query(self, query: str) -> str:
|
def enhance_query(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -135,57 +135,64 @@ class ReportSynthesizer:
|
||||||
If stream is False, returns the completion text as a string
|
If stream is False, returns the completion text as a string
|
||||||
If stream is True, returns the completion response object for streaming
|
If stream is True, returns the completion response object for streaming
|
||||||
"""
|
"""
|
||||||
|
# Get provider from model config
|
||||||
|
provider = self.model_config.get('provider', 'groq').lower()
|
||||||
|
|
||||||
|
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||||
|
if provider == 'gemini':
|
||||||
|
formatted_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
role = msg['role']
|
||||||
|
# Map 'system' to 'user' for the first message
|
||||||
|
if role == 'system' and not formatted_messages:
|
||||||
|
formatted_messages.append({
|
||||||
|
'role': 'user',
|
||||||
|
'content': msg['content']
|
||||||
|
})
|
||||||
|
# Map 'assistant' to 'model'
|
||||||
|
elif role == 'assistant':
|
||||||
|
formatted_messages.append({
|
||||||
|
'role': 'model',
|
||||||
|
'content': msg['content']
|
||||||
|
})
|
||||||
|
# Keep 'user' as is
|
||||||
|
else:
|
||||||
|
formatted_messages.append(msg)
|
||||||
|
else:
|
||||||
|
formatted_messages = messages
|
||||||
|
|
||||||
|
# Get completion parameters
|
||||||
|
params = self._get_completion_params()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
params = self._get_completion_params()
|
# Generate completion
|
||||||
|
|
||||||
# Special handling for Gemini models
|
|
||||||
if "gemini" in params.get('model', '').lower():
|
|
||||||
# Format messages for Gemini models
|
|
||||||
formatted_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
role = msg['role']
|
|
||||||
# Gemini uses 'user' and 'model' roles (not 'assistant')
|
|
||||||
if role == 'assistant':
|
|
||||||
role = 'model'
|
|
||||||
# For system messages, convert to user messages with a prefix
|
|
||||||
if role == 'system':
|
|
||||||
formatted_messages.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": f"System instruction: {msg['content']}"
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
formatted_messages.append({"role": role, "content": msg['content']})
|
|
||||||
params['messages'] = formatted_messages
|
|
||||||
else:
|
|
||||||
params['messages'] = messages
|
|
||||||
|
|
||||||
params['stream'] = stream
|
|
||||||
|
|
||||||
logger.info(f"Generating completion with model: {params.get('model')}")
|
|
||||||
logger.info(f"Provider: {self.model_config.get('provider')}")
|
|
||||||
|
|
||||||
response = completion(**params)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
response = litellm.completion(
|
||||||
|
messages=formatted_messages,
|
||||||
|
stream=True,
|
||||||
|
**params
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
else:
|
else:
|
||||||
|
response = litellm.completion(
|
||||||
|
messages=formatted_messages,
|
||||||
|
**params
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract content from response
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
# Process <thinking> tags if enabled
|
# Process thinking tags if enabled
|
||||||
if self.process_thinking_tags:
|
if self.process_thinking_tags:
|
||||||
content = self._process_thinking_tags(content)
|
content = self._process_thinking_tags(content)
|
||||||
|
|
||||||
return content
|
return content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating completion: {e}")
|
error_msg = f"Error generating completion: {str(e)}"
|
||||||
logger.error(f"Model params: {params}")
|
logger.error(error_msg)
|
||||||
|
|
||||||
# More detailed error for debugging
|
# Return error message in a user-friendly format
|
||||||
if hasattr(e, '__dict__'):
|
return f"I encountered an error while processing your request: {str(e)}"
|
||||||
for key, value in e.__dict__.items():
|
|
||||||
logger.error(f"Error detail - {key}: {value}")
|
|
||||||
|
|
||||||
return f"Error: {str(e)}"
|
|
||||||
|
|
||||||
def _process_thinking_tags(self, content: str) -> str:
|
def _process_thinking_tags(self, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -360,13 +367,36 @@ class ReportSynthesizer:
|
||||||
detail_level_manager = get_report_detail_level_manager()
|
detail_level_manager = get_report_detail_level_manager()
|
||||||
template = detail_level_manager.get_template_modifier(detail_level, query_type)
|
template = detail_level_manager.get_template_modifier(detail_level, query_type)
|
||||||
|
|
||||||
|
# Add specific instructions for references formatting
|
||||||
|
reference_instructions = """
|
||||||
|
When including references, use a consistent format:
|
||||||
|
|
||||||
|
[1] Author(s). Title. Publication. Year. URL (if available)
|
||||||
|
|
||||||
|
If author information is not available, use the website or organization name.
|
||||||
|
|
||||||
|
Always ensure the References section is complete and properly formatted at the end of the report.
|
||||||
|
Do not use placeholders like "Document X" for references - provide actual titles.
|
||||||
|
Ensure all references are properly closed with brackets and there are no incomplete references.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Special handling for Gemini models
|
||||||
|
if "gemini" in self.model_name.lower():
|
||||||
|
reference_instructions += """
|
||||||
|
IMPORTANT: Due to token limitations, ensure the References section is completed properly.
|
||||||
|
If you feel you might run out of tokens, start the References section earlier and make it more concise.
|
||||||
|
Never leave the References section incomplete or cut off mid-reference.
|
||||||
|
"""
|
||||||
|
|
||||||
# Create the prompt for synthesizing the report
|
# Create the prompt for synthesizing the report
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": f"""You are an expert research assistant tasked with creating comprehensive, well-structured reports.
|
{"role": "system", "content": f"""You are an expert research assistant tasked with creating comprehensive, well-structured reports.
|
||||||
{template}
|
{template}
|
||||||
|
|
||||||
Format the report in Markdown with clear headings, subheadings, and bullet points where appropriate.
|
Format the report in Markdown with clear headings, subheadings, and bullet points where appropriate.
|
||||||
Make the report readable, engaging, and informative while maintaining academic rigor."""},
|
Make the report readable, engaging, and informative while maintaining academic rigor.
|
||||||
|
|
||||||
|
{reference_instructions}"""},
|
||||||
{"role": "user", "content": f"""Query: {query}
|
{"role": "user", "content": f"""Query: {query}
|
||||||
|
|
||||||
Information from sources:
|
Information from sources:
|
||||||
|
@ -378,6 +408,38 @@ class ReportSynthesizer:
|
||||||
# Generate the report
|
# Generate the report
|
||||||
report = await self.generate_completion(messages)
|
report = await self.generate_completion(messages)
|
||||||
|
|
||||||
|
# Check if the report might be cut off at the end
|
||||||
|
if report.strip().endswith('[') or report.strip().endswith(']') or report.strip().endswith('...'):
|
||||||
|
logger.warning("Report appears to be cut off at the end. Attempting to fix references section.")
|
||||||
|
|
||||||
|
# Try to fix the references section by generating it separately
|
||||||
|
try:
|
||||||
|
# Extract what we have so far without the incomplete references
|
||||||
|
if "References" in report:
|
||||||
|
report_without_refs = report.split("References")[0].strip()
|
||||||
|
else:
|
||||||
|
report_without_refs = report
|
||||||
|
|
||||||
|
# Generate just the references section
|
||||||
|
ref_messages = [
|
||||||
|
{"role": "system", "content": "You are an expert at formatting reference lists. Create a properly formatted References section for the following documents:"},
|
||||||
|
{"role": "user", "content": f"""Here are the documents used in the report:
|
||||||
|
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Create a complete, properly formatted References section in Markdown format.
|
||||||
|
Use the format: [1] Title. Source URL
|
||||||
|
Make sure all references are complete and properly formatted."""}
|
||||||
|
]
|
||||||
|
|
||||||
|
references = await self.generate_completion(ref_messages)
|
||||||
|
|
||||||
|
# Combine the report with the fixed references
|
||||||
|
report = f"{report_without_refs}\n\n## References\n\n{references}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fixing references section: {str(e)}")
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
async def synthesize_report(self, chunks: List[Dict[str, Any]], query: str, query_type: str = "exploratory", detail_level: str = "standard") -> str:
|
async def synthesize_report(self, chunks: List[Dict[str, Any]], query: str, query_type: str = "exploratory", detail_level: str = "standard") -> str:
|
||||||
|
|
Loading…
Reference in New Issue