paper-system/llm_processor/client/client.go

327 lines
9.7 KiB
Go
Raw Permalink Normal View History

2025-01-24 15:26:47 +00:00
package client
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"net"
"net/http"
"os"
"strings"
"time"
"llm_processor/models"
)
// OpenRouter endpoint and retry policy.
const (
	// openRouterURL is the chat-completions endpoint every evaluation posts to.
	openRouterURL = "https://openrouter.ai/api/v1/chat/completions"
	// maxRetries caps the number of attempts EvaluatePaper makes per paper.
	maxRetries = 5
	// initialDelay is the base unit of the exponential backoff between attempts.
	initialDelay = time.Second
)

// Timeouts, from widest to narrowest scope.
const (
	evaluationTimeout = 15 * time.Minute // budget for one EvaluatePaper call, all attempts included
	requestTimeout    = 5 * time.Minute  // http.Client.Timeout for a single request
	connectionTimeout = 2 * time.Minute  // dial timeout for new TCP connections
)
// OpenRouterClient sends paper-evaluation prompts to the OpenRouter
// chat-completions API.
type OpenRouterClient struct {
	// apiKey is sent as a Bearer token in the Authorization header.
	apiKey string

	// httpClient performs the requests; it may be replaced via createClient.
	httpClient *http.Client

	// logger receives progress and error messages.
	logger *log.Logger

	// createClient builds a fresh *http.Client with this client's settings.
	createClient func() *http.Client
}
// NewOpenRouterClient returns a client that authenticates to OpenRouter with
// apiKey.
//
// The client logs to stdout and, when it can be opened, also to debug.log in
// the current working directory (opened for appending, created if missing).
// If debug.log cannot be opened the client logs to stdout only rather than
// terminating the process.
//
// Requests use a dedicated transport with a connectionTimeout dial timeout and
// a requestTimeout per-request timeout. The factory that builds that
// *http.Client is kept in createClient so the HTTP client can be rebuilt later.
func NewOpenRouterClient(apiKey string) *OpenRouterClient {
	var logOutput io.Writer = os.Stdout
	logFile, openErr := os.OpenFile("debug.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if openErr == nil {
		// NOTE(review): logFile is never closed; it stays open for the life of
		// the process.
		logOutput = io.MultiWriter(os.Stdout, logFile)
	}
	logger := log.New(logOutput, "", log.LstdFlags|log.Lshortfile)
	if openErr != nil {
		// A missing debug log is not fatal for a library constructor: keep
		// going with stdout-only logging and record why.
		logger.Printf("Failed to open debug.log, logging to stdout only: %v", openErr)
	}
	logger.Println("Initializing OpenRouter client")

	createClient := func() *http.Client {
		transport := &http.Transport{
			MaxIdleConns:        100,
			MaxIdleConnsPerHost: 100,
			IdleConnTimeout:     30 * time.Second,
			DialContext: (&net.Dialer{
				Timeout:   connectionTimeout,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			TLSHandshakeTimeout: 10 * time.Second,
		}
		return &http.Client{
			Timeout:   requestTimeout,
			Transport: transport,
		}
	}

	return &OpenRouterClient{
		apiKey:       apiKey,
		httpClient:   createClient(),
		logger:       logger,
		createClient: createClient,
	}
}
// Wire types for the OpenRouter chat-completions API. The struct tags fix the
// JSON field names; encoding/json ignores response fields not declared here.
type (
	// ChatCompletionRequest is the JSON body posted to the chat-completions
	// endpoint.
	ChatCompletionRequest struct {
		Model    string        `json:"model"`    // model identifier supplied by the caller
		Messages []ChatMessage `json:"messages"` // conversation sent to the model, in order
	}

	// ChatMessage is a single conversation turn.
	ChatMessage struct {
		Role    string `json:"role"`    // e.g. "system" or "user"
		Content string `json:"content"` // message text
	}

	// ChatCompletionResponse holds the decoded "choices" array of a response,
	// each choice carrying one message.
	ChatCompletionResponse struct {
		Choices []struct {
			Message ChatMessage `json:"message"`
		} `json:"choices"`
	}
)
// EvaluatePaper asks the OpenRouter chat model named by model whether paper
// satisfies criteria, and returns the verdict as a JSON string.
//
// When the model's reply is a JSON object (bare, or inside the first ```json
// fenced block) whose "decision" contains ACCEPT or REJECT (case-insensitive),
// the result has the form
//
//	{"decision":"ACCEPT"|"REJECT","explanation":"Original decision: <raw>\n<explanation>","paper":{"abstract":...,"arxiv_id":...,"title":...}}
//
// Otherwise a fallback document {"decision":..., "explanation":"Original
// response: <reply>"} is returned, with decision found by a case-sensitive
// keyword scan of the reply ("ERROR" if neither keyword appears). Both forms
// come with a nil error.
//
// Transport failures, non-200 responses, undecodable bodies and empty choice
// lists are retried up to maxRetries times with exponential backoff. The whole
// call, backoff included, is bounded by evaluationTimeout; once ctx is
// cancelled or that budget is spent, the context's error is returned without
// further attempts.
func (c *OpenRouterClient) EvaluatePaper(ctx context.Context, paper models.Paper, criteria string, model string) (string, error) {
	// evalCtx bounds every attempt and every backoff sleep of this evaluation.
	evalCtx, cancel := context.WithTimeout(ctx, evaluationTimeout)
	defer cancel()

	startTime := time.Now()
	c.logger.Printf("Starting evaluation for paper: %s\n", paper.Title)
	c.logger.Printf("Evaluation timeout: %s\n", evaluationTimeout)
	fmt.Printf("Starting evaluation for paper: %s at %s\n", paper.Title, startTime.Format(time.RFC3339))

	prompt := fmt.Sprintf(`Evaluate this paper based on the following criteria:
%s
Paper Title: %s
Abstract: %s
Respond ONLY with a JSON object in this exact format:
{
"decision": "ACCEPT or REJECT",
"explanation": "Your explanation here"
}
Do not include any other information in your response.
IMPORTANT:
1. The decision MUST be either "ACCEPT" or "REJECT" (uppercase)
2. The explanation should be a clear, concise reason for your decision
3. Do not include any text outside the JSON object
4. Ensure the response is valid JSON (proper quotes and escaping)
5. Do not include any markdown or formatting`, criteria, paper.Title, paper.Abstract)

	reqBody := ChatCompletionRequest{
		Model: model,
		Messages: []ChatMessage{
			{
				Role:    "system",
				Content: "You are a research paper evaluator. Respond only with the requested JSON format.",
			},
			{
				Role:    "user",
				Content: prompt,
			},
		},
	}
	// The request body is identical for every attempt, so encode it once.
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request body: %w", err)
	}

	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		attemptStart := time.Now()
		c.logger.Printf("Attempt %d started at %s\n", attempt+1, attemptStart.Format(time.RFC3339))
		if attempt > 0 {
			// Exponential backoff: initialDelay * 2^attempt (2s, 4s, 8s, ...).
			delay := time.Duration(math.Pow(2, float64(attempt))) * initialDelay
			select {
			case <-time.After(delay):
			case <-evalCtx.Done():
				return "", evalCtx.Err()
			}
		}

		req, err := http.NewRequestWithContext(evalCtx, "POST", openRouterURL, bytes.NewReader(jsonBody))
		if err != nil {
			return "", fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("HTTP-Referer", "https://github.com/stwhite/arvix")
		req.Header.Set("X-Title", "ArXiv Paper Processor")

		resp, err := c.httpClient.Do(req)
		if err != nil {
			c.logger.Printf("Attempt %d error: %v\n", attempt+1, err)
			if ctxErr := evalCtx.Err(); ctxErr != nil {
				// ctx was cancelled or the evaluation budget is spent: every
				// further attempt would fail immediately, so stop here.
				c.logger.Printf("Failed to evaluate paper: %s\n", paper.Title)
				c.logger.Printf("Total time: %s\n", time.Since(startTime))
				c.logger.Printf("Attempts: %d\n", attempt+1)
				return "", ctxErr
			}
			if errors.Is(err, context.DeadlineExceeded) {
				// evalCtx is still live, so this deadline came from a
				// lower-level timeout (e.g. http.Client.Timeout); try again.
				c.logger.Printf("Context deadline exceeded, retrying...\n")
				lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
				continue
			}
			if strings.Contains(err.Error(), "timeout") {
				c.logger.Printf("Timeout detected, recreating HTTP client...\n")
				// Release the old transport's pooled connections before
				// dropping it. NOTE(review): this field write is
				// unsynchronized; concurrent EvaluatePaper calls on one client
				// race here.
				c.httpClient.CloseIdleConnections()
				c.httpClient = c.createClient()
			}
			lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
			continue
		}

		rawContent, err := readCompletionContent(resp)
		if err != nil {
			lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
			continue
		}

		result, usedFallback, err := evaluationResultJSON(paper, rawContent)
		if err != nil {
			return "", err
		}
		if usedFallback {
			c.logger.Printf("Fallback parsing used for paper: %s\n", paper.Title)
		} else {
			c.logger.Printf("Successfully evaluated paper: %s\n", paper.Title)
		}
		c.logger.Printf("Total time: %s\n", time.Since(startTime))
		c.logger.Printf("Attempts: %d\n", attempt+1)
		return result, nil
	}

	c.logger.Printf("Failed to evaluate paper: %s\n", paper.Title)
	c.logger.Printf("Total time: %s\n", time.Since(startTime))
	c.logger.Printf("Attempts: %d\n", maxRetries)
	return "", fmt.Errorf("max retries (%d) exceeded: %w", maxRetries, lastErr)
}

// readCompletionContent consumes and closes resp and returns the message
// content of its first choice. Non-200 statuses, undecodable bodies and
// responses without choices are returned as errors.
func readCompletionContent(resp *http.Response) (string, error) {
	// Closing here, once per attempt, frees the connection immediately instead
	// of holding every attempt's body open until the retry loop returns.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: the body only enriches the error text
		return "", fmt.Errorf("openrouter request failed: %s - %s", resp.Status, string(body))
	}
	var completionResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(completionResp.Choices) == 0 {
		return "", errors.New("no choices in response")
	}
	return completionResp.Choices[0].Message.Content, nil
}

// evaluationResultJSON converts the model's raw reply into the JSON document
// EvaluatePaper returns. usedFallback reports whether the reply lacked a JSON
// object with a recognizable decision, so the keyword-scan fallback was used.
func evaluationResultJSON(paper models.Paper, rawContent string) (result string, usedFallback bool, err error) {
	parsed := parseModelJSON(rawContent)
	// Reading from a nil map is safe, so a failed parse simply yields "".
	rawDecision, _ := parsed["decision"].(string)
	if decision, ok := normalizeDecision(rawDecision); ok {
		response := map[string]interface{}{
			"paper": map[string]interface{}{
				"title":    paper.Title,
				"abstract": paper.Abstract,
				"arxiv_id": paper.ArxivID,
			},
			"decision":    decision,
			"explanation": explanationWithDecision(rawDecision, parsed["explanation"]),
		}
		responseJSON, err := json.Marshal(response)
		if err != nil {
			return "", false, fmt.Errorf("failed to marshal response: %w", err)
		}
		return string(responseJSON), false, nil
	}

	// Fallback: case-sensitive keyword scan of the raw reply; ACCEPT wins if
	// both keywords appear.
	decision := "ERROR"
	if strings.Contains(rawContent, "ACCEPT") {
		decision = "ACCEPT"
	} else if strings.Contains(rawContent, "REJECT") {
		decision = "REJECT"
	}
	fallbackJSON, err := json.Marshal(map[string]interface{}{
		"decision":    decision,
		"explanation": fmt.Sprintf("Original response: %s", rawContent),
	})
	if err != nil {
		return "", true, fmt.Errorf("failed to marshal fallback response: %w", err)
	}
	return string(fallbackJSON), true, nil
}

// parseModelJSON decodes rawContent as a JSON object, first as-is and then
// from its first ```json fenced block. It returns nil if neither decodes to an
// object (a JSON null also yields nil).
func parseModelJSON(rawContent string) map[string]interface{} {
	var direct map[string]interface{}
	if err := json.Unmarshal([]byte(rawContent), &direct); err == nil {
		return direct
	}
	fenced, ok := extractFencedJSON(rawContent)
	if !ok {
		return nil
	}
	var fromFence map[string]interface{}
	if err := json.Unmarshal([]byte(fenced), &fromFence); err != nil {
		return nil
	}
	return fromFence
}

// extractFencedJSON returns the text between the first "```json" marker in s
// and the next "```" after it, and whether such a block was found.
func extractFencedJSON(s string) (string, bool) {
	const openFence = "```json"
	start := strings.Index(s, openFence)
	if start < 0 {
		return "", false
	}
	rest := s[start+len(openFence):]
	end := strings.Index(rest, "```")
	if end < 0 {
		return "", false
	}
	return rest[:end], true
}

// normalizeDecision maps a model-supplied decision to "ACCEPT" or "REJECT" by
// case-insensitive substring match (ACCEPT checked first). ok is false when
// neither keyword is present.
func normalizeDecision(decision string) (normalized string, ok bool) {
	upper := strings.ToUpper(strings.TrimSpace(decision))
	switch {
	case strings.Contains(upper, "ACCEPT"):
		return "ACCEPT", true
	case strings.Contains(upper, "REJECT"):
		return "REJECT", true
	default:
		return "", false
	}
}

// explanationWithDecision prefixes the model's explanation with the decision
// text exactly as the model gave it, so normalization does not lose it. A
// missing or null explanation yields only the prefix line; a non-string
// explanation is embedded as its JSON encoding rather than Go's %s rendering.
func explanationWithDecision(rawDecision string, explanation interface{}) string {
	prefix := fmt.Sprintf("Original decision: %s\n", rawDecision)
	switch e := explanation.(type) {
	case nil:
		return prefix
	case string:
		return prefix + e
	default:
		encoded, err := json.Marshal(e)
		if err != nil {
			// Values decoded from JSON always re-encode; this is defensive.
			return prefix + fmt.Sprint(e)
		}
		return prefix + string(encoded)
	}
}