Fix reference formatting for Gemini models and improve message handling

Steve White 2025-02-28 16:32:53 -06:00
parent d4beb73a7a
commit 8cd2f900c1
2 changed files with 156 additions and 71 deletions


@@ -127,7 +127,7 @@ class LLMInterface:
         return params
 
-    def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
+    async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
         """
         Generate a completion using the configured LLM.
@@ -139,41 +139,64 @@ class LLMInterface:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
+        # Get provider from model config
+        provider = self.model_config.get('provider', 'openai').lower()
+
+        # Special handling for Gemini models - they use 'user' and 'model' roles
+        if provider == 'gemini':
+            formatted_messages = []
+            for msg in messages:
+                role = msg['role']
+                # Map 'system' to 'user' for the first message
+                if role == 'system' and not formatted_messages:
+                    formatted_messages.append({
+                        'role': 'user',
+                        'content': msg['content']
+                    })
+                # Map 'assistant' to 'model'
+                elif role == 'assistant':
+                    formatted_messages.append({
+                        'role': 'model',
+                        'content': msg['content']
+                    })
+                # Keep 'user' as is
+                else:
+                    formatted_messages.append(msg)
+        else:
+            formatted_messages = messages
+
+        # Get completion parameters
+        params = self._get_completion_params()
+
         try:
-            params = self._get_completion_params()
-
-            # Special handling for Gemini models
-            if "gemini" in params.get('model', '').lower():
-                # Format messages for Gemini models
-                formatted_messages = []
-                for msg in messages:
-                    role = msg['role']
-                    # Gemini uses 'user' and 'model' roles (not 'assistant')
-                    if role == 'assistant':
-                        role = 'model'
-                    # For system messages, convert to user messages with a prefix
-                    if role == 'system':
-                        formatted_messages.append({
-                            "role": "user",
-                            "content": f"System instruction: {msg['content']}"
-                        })
-                    else:
-                        formatted_messages.append({"role": role, "content": msg['content']})
-                params['messages'] = formatted_messages
-            else:
-                params['messages'] = messages
-
-            params['stream'] = stream
-            response = completion(**params)
+            # Generate completion
             if stream:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    stream=True,
+                    **params
+                )
                 return response
             else:
-                return response.choices[0].message.content
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    **params
+                )
+
+                # Extract content from response
+                content = response.choices[0].message.content
+
+                # Process thinking tags if enabled
+                if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
+                    content = self._process_thinking_tags(content)
+
+                return content
         except Exception as e:
-            print(f"Error generating completion: {e}")
-            return f"Error: {str(e)}"
+            error_msg = f"Error generating completion: {str(e)}"
+            print(error_msg)
+
+            # Return error message in a user-friendly format
+            return f"I encountered an error while processing your request: {str(e)}"
 
     def enhance_query(self, query: str) -> str:
         """


@@ -135,57 +135,64 @@ class ReportSynthesizer:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
+        # Get provider from model config
+        provider = self.model_config.get('provider', 'groq').lower()
+
+        # Special handling for Gemini models - they use 'user' and 'model' roles
+        if provider == 'gemini':
+            formatted_messages = []
+            for msg in messages:
+                role = msg['role']
+                # Map 'system' to 'user' for the first message
+                if role == 'system' and not formatted_messages:
+                    formatted_messages.append({
+                        'role': 'user',
+                        'content': msg['content']
+                    })
+                # Map 'assistant' to 'model'
+                elif role == 'assistant':
+                    formatted_messages.append({
+                        'role': 'model',
+                        'content': msg['content']
+                    })
+                # Keep 'user' as is
+                else:
+                    formatted_messages.append(msg)
+        else:
+            formatted_messages = messages
+
+        # Get completion parameters
+        params = self._get_completion_params()
+
         try:
-            params = self._get_completion_params()
-
-            # Special handling for Gemini models
-            if "gemini" in params.get('model', '').lower():
-                # Format messages for Gemini models
-                formatted_messages = []
-                for msg in messages:
-                    role = msg['role']
-                    # Gemini uses 'user' and 'model' roles (not 'assistant')
-                    if role == 'assistant':
-                        role = 'model'
-                    # For system messages, convert to user messages with a prefix
-                    if role == 'system':
-                        formatted_messages.append({
-                            "role": "user",
-                            "content": f"System instruction: {msg['content']}"
-                        })
-                    else:
-                        formatted_messages.append({"role": role, "content": msg['content']})
-                params['messages'] = formatted_messages
-            else:
-                params['messages'] = messages
-
-            params['stream'] = stream
-            logger.info(f"Generating completion with model: {params.get('model')}")
-            logger.info(f"Provider: {self.model_config.get('provider')}")
-            response = completion(**params)
+            # Generate completion
             if stream:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    stream=True,
+                    **params
+                )
                 return response
             else:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    **params
+                )
+
+                # Extract content from response
                 content = response.choices[0].message.content
 
-                # Process <thinking> tags if enabled
+                # Process thinking tags if enabled
                 if self.process_thinking_tags:
                     content = self._process_thinking_tags(content)
 
                 return content
         except Exception as e:
-            logger.error(f"Error generating completion: {e}")
-            logger.error(f"Model params: {params}")
-            # More detailed error for debugging
-            if hasattr(e, '__dict__'):
-                for key, value in e.__dict__.items():
-                    logger.error(f"Error detail - {key}: {value}")
-            return f"Error: {str(e)}"
+            error_msg = f"Error generating completion: {str(e)}"
+            logger.error(error_msg)
+
+            # Return error message in a user-friendly format
+            return f"I encountered an error while processing your request: {str(e)}"
 
     def _process_thinking_tags(self, content: str) -> str:
         """
@@ -360,13 +367,36 @@ class ReportSynthesizer:
         detail_level_manager = get_report_detail_level_manager()
         template = detail_level_manager.get_template_modifier(detail_level, query_type)
 
+        # Add specific instructions for references formatting
+        reference_instructions = """
+        When including references, use a consistent format:
+        [1] Author(s). Title. Publication. Year. URL (if available)
+        If author information is not available, use the website or organization name.
+        Always ensure the References section is complete and properly formatted at the end of the report.
+        Do not use placeholders like "Document X" for references - provide actual titles.
+        Ensure all references are properly closed with brackets and there are no incomplete references.
+        """
+
+        # Special handling for Gemini models
+        if "gemini" in self.model_name.lower():
+            reference_instructions += """
+            IMPORTANT: Due to token limitations, ensure the References section is completed properly.
+            If you feel you might run out of tokens, start the References section earlier and make it more concise.
+            Never leave the References section incomplete or cut off mid-reference.
+            """
+
         # Create the prompt for synthesizing the report
         messages = [
             {"role": "system", "content": f"""You are an expert research assistant tasked with creating comprehensive, well-structured reports.
             {template}
             Format the report in Markdown with clear headings, subheadings, and bullet points where appropriate.
-            Make the report readable, engaging, and informative while maintaining academic rigor."""},
+            Make the report readable, engaging, and informative while maintaining academic rigor.
+            {reference_instructions}"""},
             {"role": "user", "content": f"""Query: {query}
 
             Information from sources:
@@ -378,6 +408,38 @@ class ReportSynthesizer:
         # Generate the report
         report = await self.generate_completion(messages)
 
+        # Check if the report might be cut off at the end
+        if report.strip().endswith('[') or report.strip().endswith(']') or report.strip().endswith('...'):
+            logger.warning("Report appears to be cut off at the end. Attempting to fix references section.")
+
+            # Try to fix the references section by generating it separately
+            try:
+                # Extract what we have so far without the incomplete references
+                if "References" in report:
+                    report_without_refs = report.split("References")[0].strip()
+                else:
+                    report_without_refs = report
+
+                # Generate just the references section
+                ref_messages = [
+                    {"role": "system", "content": "You are an expert at formatting reference lists. Create a properly formatted References section for the following documents:"},
+                    {"role": "user", "content": f"""Here are the documents used in the report:
+                    {context}
+                    Create a complete, properly formatted References section in Markdown format.
+                    Use the format: [1] Title. Source URL
+                    Make sure all references are complete and properly formatted."""}
+                ]
+
+                references = await self.generate_completion(ref_messages)
+
+                # Combine the report with the fixed references
+                report = f"{report_without_refs}\n\n## References\n\n{references}"
+            except Exception as e:
+                logger.error(f"Error fixing references section: {str(e)}")
+
         return report
 
     async def synthesize_report(self, chunks: List[Dict[str, Any]], query: str, query_type: str = "exploratory", detail_level: str = "standard") -> str:
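The cut-off handling added at the end of the synthesis path is a plain suffix heuristic followed by a second completion call for just the references. The same idea as a standalone sketch (the helper names are illustrative; the commit inlines this logic):

def looks_truncated(report: str) -> bool:
    # Same heuristic as the commit: a report whose stripped text ends in a
    # bracket or an ellipsis was probably cut off mid-reference.
    tail = report.strip()
    return tail.endswith('[') or tail.endswith(']') or tail.endswith('...')

def strip_incomplete_references(report: str) -> str:
    # As in the commit, everything from the first occurrence of "References"
    # onward is dropped so a fresh section can be regenerated separately and
    # reattached under a '## References' heading.
    if "References" in report:
        return report.split("References")[0].strip()
    return report

One design note: split("References") keys on the first occurrence of the word anywhere in the report, not specifically on the section heading, so an in-text mention of "References" would cut the report earlier than intended.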