Fix reference formatting for Gemini models and improve message handling
This commit is contained in:
parent
d4beb73a7a
commit
8cd2f900c1
|
@ -127,7 +127,7 @@ class LLMInterface:
|
|||
|
||||
return params
|
||||
|
||||
def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
|
||||
async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
|
||||
"""
|
||||
Generate a completion using the configured LLM.
|
||||
|
||||
|
@ -139,41 +139,64 @@ class LLMInterface:
|
|||
If stream is False, returns the completion text as a string
|
||||
If stream is True, returns the completion response object for streaming
|
||||
"""
|
||||
try:
|
||||
params = self._get_completion_params()
|
||||
# Get provider from model config
|
||||
provider = self.model_config.get('provider', 'openai').lower()
|
||||
|
||||
# Special handling for Gemini models
|
||||
if "gemini" in params.get('model', '').lower():
|
||||
# Format messages for Gemini models
|
||||
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||
if provider == 'gemini':
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
role = msg['role']
|
||||
# Gemini uses 'user' and 'model' roles (not 'assistant')
|
||||
if role == 'assistant':
|
||||
role = 'model'
|
||||
# For system messages, convert to user messages with a prefix
|
||||
if role == 'system':
|
||||
# Map 'system' to 'user' for the first message
|
||||
if role == 'system' and not formatted_messages:
|
||||
formatted_messages.append({
|
||||
"role": "user",
|
||||
"content": f"System instruction: {msg['content']}"
|
||||
'role': 'user',
|
||||
'content': msg['content']
|
||||
})
|
||||
# Map 'assistant' to 'model'
|
||||
elif role == 'assistant':
|
||||
formatted_messages.append({
|
||||
'role': 'model',
|
||||
'content': msg['content']
|
||||
})
|
||||
# Keep 'user' as is
|
||||
else:
|
||||
formatted_messages.append({"role": role, "content": msg['content']})
|
||||
params['messages'] = formatted_messages
|
||||
formatted_messages.append(msg)
|
||||
else:
|
||||
params['messages'] = messages
|
||||
formatted_messages = messages
|
||||
|
||||
params['stream'] = stream
|
||||
|
||||
response = completion(**params)
|
||||
# Get completion parameters
|
||||
params = self._get_completion_params()
|
||||
|
||||
try:
|
||||
# Generate completion
|
||||
if stream:
|
||||
response = litellm.completion(
|
||||
messages=formatted_messages,
|
||||
stream=True,
|
||||
**params
|
||||
)
|
||||
return response
|
||||
else:
|
||||
return response.choices[0].message.content
|
||||
response = litellm.completion(
|
||||
messages=formatted_messages,
|
||||
**params
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# Process thinking tags if enabled
|
||||
if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
|
||||
content = self._process_thinking_tags(content)
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
print(f"Error generating completion: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
error_msg = f"Error generating completion: {str(e)}"
|
||||
print(error_msg)
|
||||
|
||||
# Return error message in a user-friendly format
|
||||
return f"I encountered an error while processing your request: {str(e)}"
|
||||
|
||||
def enhance_query(self, query: str) -> str:
|
||||
"""
|
||||
|
|
|
@ -135,57 +135,64 @@ class ReportSynthesizer:
|
|||
If stream is False, returns the completion text as a string
|
||||
If stream is True, returns the completion response object for streaming
|
||||
"""
|
||||
try:
|
||||
params = self._get_completion_params()
|
||||
# Get provider from model config
|
||||
provider = self.model_config.get('provider', 'groq').lower()
|
||||
|
||||
# Special handling for Gemini models
|
||||
if "gemini" in params.get('model', '').lower():
|
||||
# Format messages for Gemini models
|
||||
# Special handling for Gemini models - they use 'user' and 'model' roles
|
||||
if provider == 'gemini':
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
role = msg['role']
|
||||
# Gemini uses 'user' and 'model' roles (not 'assistant')
|
||||
if role == 'assistant':
|
||||
role = 'model'
|
||||
# For system messages, convert to user messages with a prefix
|
||||
if role == 'system':
|
||||
# Map 'system' to 'user' for the first message
|
||||
if role == 'system' and not formatted_messages:
|
||||
formatted_messages.append({
|
||||
"role": "user",
|
||||
"content": f"System instruction: {msg['content']}"
|
||||
'role': 'user',
|
||||
'content': msg['content']
|
||||
})
|
||||
# Map 'assistant' to 'model'
|
||||
elif role == 'assistant':
|
||||
formatted_messages.append({
|
||||
'role': 'model',
|
||||
'content': msg['content']
|
||||
})
|
||||
# Keep 'user' as is
|
||||
else:
|
||||
formatted_messages.append({"role": role, "content": msg['content']})
|
||||
params['messages'] = formatted_messages
|
||||
formatted_messages.append(msg)
|
||||
else:
|
||||
params['messages'] = messages
|
||||
formatted_messages = messages
|
||||
|
||||
params['stream'] = stream
|
||||
|
||||
logger.info(f"Generating completion with model: {params.get('model')}")
|
||||
logger.info(f"Provider: {self.model_config.get('provider')}")
|
||||
|
||||
response = completion(**params)
|
||||
# Get completion parameters
|
||||
params = self._get_completion_params()
|
||||
|
||||
try:
|
||||
# Generate completion
|
||||
if stream:
|
||||
response = litellm.completion(
|
||||
messages=formatted_messages,
|
||||
stream=True,
|
||||
**params
|
||||
)
|
||||
return response
|
||||
else:
|
||||
response = litellm.completion(
|
||||
messages=formatted_messages,
|
||||
**params
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# Process <thinking> tags if enabled
|
||||
# Process thinking tags if enabled
|
||||
if self.process_thinking_tags:
|
||||
content = self._process_thinking_tags(content)
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating completion: {e}")
|
||||
logger.error(f"Model params: {params}")
|
||||
error_msg = f"Error generating completion: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# More detailed error for debugging
|
||||
if hasattr(e, '__dict__'):
|
||||
for key, value in e.__dict__.items():
|
||||
logger.error(f"Error detail - {key}: {value}")
|
||||
|
||||
return f"Error: {str(e)}"
|
||||
# Return error message in a user-friendly format
|
||||
return f"I encountered an error while processing your request: {str(e)}"
|
||||
|
||||
def _process_thinking_tags(self, content: str) -> str:
|
||||
"""
|
||||
|
@ -360,13 +367,36 @@ class ReportSynthesizer:
|
|||
detail_level_manager = get_report_detail_level_manager()
|
||||
template = detail_level_manager.get_template_modifier(detail_level, query_type)
|
||||
|
||||
# Add specific instructions for references formatting
|
||||
reference_instructions = """
|
||||
When including references, use a consistent format:
|
||||
|
||||
[1] Author(s). Title. Publication. Year. URL (if available)
|
||||
|
||||
If author information is not available, use the website or organization name.
|
||||
|
||||
Always ensure the References section is complete and properly formatted at the end of the report.
|
||||
Do not use placeholders like "Document X" for references - provide actual titles.
|
||||
Ensure all references are properly closed with brackets and there are no incomplete references.
|
||||
"""
|
||||
|
||||
# Special handling for Gemini models
|
||||
if "gemini" in self.model_name.lower():
|
||||
reference_instructions += """
|
||||
IMPORTANT: Due to token limitations, ensure the References section is completed properly.
|
||||
If you feel you might run out of tokens, start the References section earlier and make it more concise.
|
||||
Never leave the References section incomplete or cut off mid-reference.
|
||||
"""
|
||||
|
||||
# Create the prompt for synthesizing the report
|
||||
messages = [
|
||||
{"role": "system", "content": f"""You are an expert research assistant tasked with creating comprehensive, well-structured reports.
|
||||
{template}
|
||||
|
||||
Format the report in Markdown with clear headings, subheadings, and bullet points where appropriate.
|
||||
Make the report readable, engaging, and informative while maintaining academic rigor."""},
|
||||
Make the report readable, engaging, and informative while maintaining academic rigor.
|
||||
|
||||
{reference_instructions}"""},
|
||||
{"role": "user", "content": f"""Query: {query}
|
||||
|
||||
Information from sources:
|
||||
|
@ -378,6 +408,38 @@ class ReportSynthesizer:
|
|||
# Generate the report
|
||||
report = await self.generate_completion(messages)
|
||||
|
||||
# Check if the report might be cut off at the end
|
||||
if report.strip().endswith('[') or report.strip().endswith(']') or report.strip().endswith('...'):
|
||||
logger.warning("Report appears to be cut off at the end. Attempting to fix references section.")
|
||||
|
||||
# Try to fix the references section by generating it separately
|
||||
try:
|
||||
# Extract what we have so far without the incomplete references
|
||||
if "References" in report:
|
||||
report_without_refs = report.split("References")[0].strip()
|
||||
else:
|
||||
report_without_refs = report
|
||||
|
||||
# Generate just the references section
|
||||
ref_messages = [
|
||||
{"role": "system", "content": "You are an expert at formatting reference lists. Create a properly formatted References section for the following documents:"},
|
||||
{"role": "user", "content": f"""Here are the documents used in the report:
|
||||
|
||||
{context}
|
||||
|
||||
Create a complete, properly formatted References section in Markdown format.
|
||||
Use the format: [1] Title. Source URL
|
||||
Make sure all references are complete and properly formatted."""}
|
||||
]
|
||||
|
||||
references = await self.generate_completion(ref_messages)
|
||||
|
||||
# Combine the report with the fixed references
|
||||
report = f"{report_without_refs}\n\n## References\n\n{references}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fixing references section: {str(e)}")
|
||||
|
||||
return report
|
||||
|
||||
async def synthesize_report(self, chunks: List[Dict[str, Any]], query: str, query_type: str = "exploratory", detail_level: str = "standard") -> str:
|
||||
|
|
Loading…
Reference in New Issue