diff --git a/query/llm_interface.py b/query/llm_interface.py
index 2f98e0e..ed18090 100644
--- a/query/llm_interface.py
+++ b/query/llm_interface.py
@@ -127,7 +127,28 @@ class LLMInterface:
         """
         try:
             params = self._get_completion_params()
-            params['messages'] = messages
+
+            # Special handling for Gemini models
+            if "gemini" in params.get('model', '').lower():
+                # Format messages for Gemini models
+                formatted_messages = []
+                for msg in messages:
+                    role = msg['role']
+                    # Gemini uses 'user' and 'model' roles (not 'assistant')
+                    if role == 'assistant':
+                        role = 'model'
+                    # For system messages, convert to user messages with a prefix
+                    if role == 'system':
+                        formatted_messages.append({
+                            "role": "user",
+                            "content": f"System instruction: {msg['content']}"
+                        })
+                    else:
+                        formatted_messages.append({"role": role, "content": msg['content']})
+                params['messages'] = formatted_messages
+            else:
+                params['messages'] = messages
+
             params['stream'] = stream
 
             response = completion(**params)
diff --git a/report/report_synthesis.py b/report/report_synthesis.py
index 5f4bcec..ca0eacd 100644
--- a/report/report_synthesis.py
+++ b/report/report_synthesis.py
@@ -123,7 +123,28 @@ class ReportSynthesizer:
         """
         try:
             params = self._get_completion_params()
-            params['messages'] = messages
+
+            # Special handling for Gemini models
+            if "gemini" in params.get('model', '').lower():
+                # Format messages for Gemini models
+                formatted_messages = []
+                for msg in messages:
+                    role = msg['role']
+                    # Gemini uses 'user' and 'model' roles (not 'assistant')
+                    if role == 'assistant':
+                        role = 'model'
+                    # For system messages, convert to user messages with a prefix
+                    if role == 'system':
+                        formatted_messages.append({
+                            "role": "user",
+                            "content": f"System instruction: {msg['content']}"
+                        })
+                    else:
+                        formatted_messages.append({"role": role, "content": msg['content']})
+                params['messages'] = formatted_messages
+            else:
+                params['messages'] = messages
+
             params['stream'] = stream
 
             logger.info(f"Generating completion with model: {params.get('model')}")