From 8cd2f900c136178e4bc8775a75ce9ec7843830e7 Mon Sep 17 00:00:00 2001
From: Steve White
Date: Fri, 28 Feb 2025 16:32:53 -0600
Subject: [PATCH] Fix reference formatting for Gemini models and improve
 message handling

---
 query/llm_interface.py     |  85 ++++++++++++++--------
 report/report_synthesis.py | 142 ++++++++++++++++++++++++++-----------
 2 files changed, 156 insertions(+), 71 deletions(-)

diff --git a/query/llm_interface.py b/query/llm_interface.py
index 7d818b8..99af01e 100644
--- a/query/llm_interface.py
+++ b/query/llm_interface.py
@@ -127,7 +127,7 @@ class LLMInterface:
 
         return params
 
-    def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
+    async def generate_completion(self, messages: List[Dict[str, str]], stream: bool = False) -> Union[str, Any]:
         """
         Generate a completion using the configured LLM.
 
@@ -139,41 +139,64 @@ class LLMInterface:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
+        # Get provider from model config
+        provider = self.model_config.get('provider', 'openai').lower()
+
+        # Special handling for Gemini models - they use 'user' and 'model' roles
+        if provider == 'gemini':
+            formatted_messages = []
+            for msg in messages:
+                role = msg['role']
+                # Map 'system' to 'user' for the first message
+                if role == 'system' and not formatted_messages:
+                    formatted_messages.append({
+                        'role': 'user',
+                        'content': msg['content']
+                    })
+                # Map 'assistant' to 'model'
+                elif role == 'assistant':
+                    formatted_messages.append({
+                        'role': 'model',
+                        'content': msg['content']
+                    })
+                # Keep 'user' as is
+                else:
+                    formatted_messages.append(msg)
+        else:
+            formatted_messages = messages
+
+        # Get completion parameters
+        params = self._get_completion_params()
+
         try:
-            params = self._get_completion_params()
-
-            # Special handling for Gemini models
-            if "gemini" in params.get('model', '').lower():
-                # Format messages for Gemini models
-                formatted_messages = []
-                for msg in messages:
-                    role = msg['role']
-                    # Gemini uses 'user' and 'model' roles (not 'assistant')
-                    if role == 'assistant':
-                        role = 'model'
-                    # For system messages, convert to user messages with a prefix
-                    if role == 'system':
-                        formatted_messages.append({
-                            "role": "user",
-                            "content": f"System instruction: {msg['content']}"
-                        })
-                    else:
-                        formatted_messages.append({"role": role, "content": msg['content']})
-                params['messages'] = formatted_messages
-            else:
-                params['messages'] = messages
-
-            params['stream'] = stream
-
-            response = completion(**params)
-
+            # Generate completion
             if stream:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    stream=True,
+                    **params
+                )
                 return response
             else:
-                return response.choices[0].message.content
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    **params
+                )
+
+                # Extract content from response
+                content = response.choices[0].message.content
+
+                # Process thinking tags if enabled
+                if hasattr(self, 'process_thinking_tags') and self.process_thinking_tags:
+                    content = self._process_thinking_tags(content)
+
+                return content
         except Exception as e:
-            print(f"Error generating completion: {e}")
-            return f"Error: {str(e)}"
+            error_msg = f"Error generating completion: {str(e)}"
+            print(error_msg)
+
+            # Return error message in a user-friendly format
+            return f"I encountered an error while processing your request: {str(e)}"
 
     def enhance_query(self, query: str) -> str:
         """
diff --git a/report/report_synthesis.py b/report/report_synthesis.py
index 5a11ebc..d6e13d7 100644
--- a/report/report_synthesis.py
+++ b/report/report_synthesis.py
@@ -135,57 +135,64 @@ class ReportSynthesizer:
             If stream is False, returns the completion text as a string
             If stream is True, returns the completion response object for streaming
         """
+        # Get provider from model config
+        provider = self.model_config.get('provider', 'groq').lower()
+
+        # Special handling for Gemini models - they use 'user' and 'model' roles
+        if provider == 'gemini':
+            formatted_messages = []
+            for msg in messages:
+                role = msg['role']
+                # Map 'system' to 'user' for the first message
+                if role == 'system' and not formatted_messages:
+                    formatted_messages.append({
+                        'role': 'user',
+                        'content': msg['content']
+                    })
+                # Map 'assistant' to 'model'
+                elif role == 'assistant':
+                    formatted_messages.append({
+                        'role': 'model',
+                        'content': msg['content']
+                    })
+                # Keep 'user' as is
+                else:
+                    formatted_messages.append(msg)
+        else:
+            formatted_messages = messages
+
+        # Get completion parameters
+        params = self._get_completion_params()
+
         try:
-            params = self._get_completion_params()
-
-            # Special handling for Gemini models
-            if "gemini" in params.get('model', '').lower():
-                # Format messages for Gemini models
-                formatted_messages = []
-                for msg in messages:
-                    role = msg['role']
-                    # Gemini uses 'user' and 'model' roles (not 'assistant')
-                    if role == 'assistant':
-                        role = 'model'
-                    # For system messages, convert to user messages with a prefix
-                    if role == 'system':
-                        formatted_messages.append({
-                            "role": "user",
-                            "content": f"System instruction: {msg['content']}"
-                        })
-                    else:
-                        formatted_messages.append({"role": role, "content": msg['content']})
-                params['messages'] = formatted_messages
-            else:
-                params['messages'] = messages
-
-            params['stream'] = stream
-
-            logger.info(f"Generating completion with model: {params.get('model')}")
-            logger.info(f"Provider: {self.model_config.get('provider')}")
-
-            response = completion(**params)
-
+            # Generate completion
             if stream:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    stream=True,
+                    **params
+                )
                 return response
             else:
+                response = litellm.completion(
+                    messages=formatted_messages,
+                    **params
+                )
+
+                # Extract content from response
                 content = response.choices[0].message.content
 
-                # Process tags if enabled
+                # Process thinking tags if enabled
                 if self.process_thinking_tags:
                     content = self._process_thinking_tags(content)
 
                 return content
         except Exception as e:
-            logger.error(f"Error generating completion: {e}")
-            logger.error(f"Model params: {params}")
+            error_msg = f"Error generating completion: {str(e)}"
+            logger.error(error_msg)
 
-            # More detailed error for debugging
-            if hasattr(e, '__dict__'):
-                for key, value in e.__dict__.items():
-                    logger.error(f"Error detail - {key}: {value}")
-
-            return f"Error: {str(e)}"
+            # Return error message in a user-friendly format
+            return f"I encountered an error while processing your request: {str(e)}"
 
     def _process_thinking_tags(self, content: str) -> str:
         """
@@ -360,13 +367,36 @@ class ReportSynthesizer:
         detail_level_manager = get_report_detail_level_manager()
         template = detail_level_manager.get_template_modifier(detail_level, query_type)
+        # Add specific instructions for references formatting
+        reference_instructions = """
+        When including references, use a consistent format:
+
+        [1] Author(s). Title. Publication. Year. URL (if available)
+
+        If author information is not available, use the website or organization name.
+
+        Always ensure the References section is complete and properly formatted at the end of the report.
+        Do not use placeholders like "Document X" for references - provide actual titles.
+        Ensure all references are properly closed with brackets and there are no incomplete references.
+        """
+
+        # Special handling for Gemini models
+        if "gemini" in self.model_name.lower():
+            reference_instructions += """
+            IMPORTANT: Due to token limitations, ensure the References section is completed properly.
+            If you feel you might run out of tokens, start the References section earlier and make it more concise.
+            Never leave the References section incomplete or cut off mid-reference.
+            """
+
         # Create the prompt for synthesizing the report
         messages = [
             {"role": "system", "content": f"""You are an expert research assistant tasked with creating comprehensive, well-structured reports.
 
            {template}
 
            Format the report in Markdown with clear headings, subheadings, and bullet points where appropriate.
-            Make the report readable, engaging, and informative while maintaining academic rigor."""},
+            Make the report readable, engaging, and informative while maintaining academic rigor.
+
+            {reference_instructions}"""},
             {"role": "user", "content": f"""Query: {query}
 
            Information from sources:
@@ -378,6 +408,38 @@ class ReportSynthesizer:
         # Generate the report
         report = await self.generate_completion(messages)
+        # Check if the report might be cut off at the end
+        if report.strip().endswith('[') or report.strip().endswith(']') or report.strip().endswith('...'):
+            logger.warning("Report appears to be cut off at the end. Attempting to fix references section.")
+
+            # Try to fix the references section by generating it separately
+            try:
+                # Extract what we have so far without the incomplete references
+                if "References" in report:
+                    report_without_refs = report.split("References")[0].strip()
+                else:
+                    report_without_refs = report
+
+                # Generate just the references section
+                ref_messages = [
+                    {"role": "system", "content": "You are an expert at formatting reference lists. Create a properly formatted References section for the following documents:"},
+                    {"role": "user", "content": f"""Here are the documents used in the report:
+
+                    {context}
+
+                    Create a complete, properly formatted References section in Markdown format.
+                    Use the format: [1] Title. Source URL
+                    Make sure all references are complete and properly formatted."""}
+                ]
+
+                references = await self.generate_completion(ref_messages)
+
+                # Combine the report with the fixed references
+                report = f"{report_without_refs}\n\n## References\n\n{references}"
+
+            except Exception as e:
+                logger.error(f"Error fixing references section: {str(e)}")
+
 
         return report
 
     async def synthesize_report(self, chunks: List[Dict[str, Any]], query: str, query_type: str = "exploratory", detail_level: str = "standard") -> str:
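Note: the Gemini role mapping introduced in both files can be exercised in isolation. Below is a minimal standalone sketch, not part of the patch itself: the helper name map_roles_for_gemini is hypothetical (the patch inlines this logic in each method), and the only assumption is the message-dict shape used in the hunks above ('role'/'content' keys).

    from typing import Dict, List

    def map_roles_for_gemini(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
        # Gemini expects 'user' and 'model' roles rather than the
        # OpenAI-style 'system'/'user'/'assistant' set.
        formatted: List[Dict[str, str]] = []
        for msg in messages:
            role = msg['role']
            if role == 'system' and not formatted:
                # A leading system prompt is folded into the first user turn.
                formatted.append({'role': 'user', 'content': msg['content']})
            elif role == 'assistant':
                formatted.append({'role': 'model', 'content': msg['content']})
            else:
                formatted.append(msg)
        return formatted

    if __name__ == '__main__':
        demo = [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hello'},
            {'role': 'assistant', 'content': 'Hi there'},
        ]
        print([m['role'] for m in map_roles_for_gemini(demo)])  # ['user', 'user', 'model']

One behavioral observation: a 'system' message that is not first falls through to the else branch unchanged, exactly as in the patched code, whereas the replaced code converted every system message into a prefixed user message.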
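Similarly, the cut-off detection and repair flow added in the final hunk can be summarized as one function. This is a compact sketch under stated assumptions: repair_truncated_references and its synthesize parameter are hypothetical stand-ins for the inline logic and self.generate_completion; only standard str.endswith (which accepts a tuple of suffixes) and str.split are relied on.

    from typing import Awaitable, Callable, Dict, List

    Messages = List[Dict[str, str]]

    async def repair_truncated_references(
        report: str,
        context: str,
        synthesize: Callable[[Messages], Awaitable[str]],
    ) -> str:
        # Heuristic from the patch: a report ending in '[', ']' or '...'
        # has probably lost its tail to the output token limit.
        if not report.strip().endswith(('[', ']', '...')):
            return report

        # Drop the incomplete References section, if one was started.
        body = report.split("References")[0].strip() if "References" in report else report

        try:
            # Ask the model for just the references, then reattach them.
            refs = await synthesize([
                {"role": "system", "content": "You are an expert at formatting reference lists."},
                {"role": "user", "content": f"Create a complete, properly formatted References section for:\n\n{context}"},
            ])
        except Exception:
            # As in the patch, fall back to the original report if regeneration fails.
            return report

        return f"{body}\n\n## References\n\n{refs}"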