327 lines
9.7 KiB
Go
327 lines
9.7 KiB
Go
|
package client
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log"
|
||
|
"math"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"llm_processor/models"
|
||
|
)
|
||
|
|
||
|
// Endpoint and timing configuration shared by OpenRouterClient.
const (
	// openRouterURL is the chat-completions endpoint requests are POSTed to.
	openRouterURL = "https://openrouter.ai/api/v1/chat/completions"

	// maxRetries is the number of attempts EvaluatePaper makes before giving up.
	maxRetries = 5
	// initialDelay is the base unit of the exponential backoff between attempts.
	initialDelay = time.Second

	// evaluationTimeout bounds one whole EvaluatePaper call, retries included.
	evaluationTimeout = 15 * time.Minute
	// requestTimeout is the http.Client timeout applied to each request.
	requestTimeout = 5 * time.Minute
	// connectionTimeout is the dialer timeout for establishing a connection.
	connectionTimeout = 2 * time.Minute
)
|
||
|
|
||
|
// OpenRouterClient sends chat-completion requests to the OpenRouter API.
type OpenRouterClient struct {
	// apiKey is sent as a Bearer token in the Authorization header.
	apiKey string
	// httpClient performs the requests; it can be replaced by a fresh
	// client produced by createClient.
	httpClient *http.Client
	// logger receives progress and diagnostic messages.
	logger *log.Logger
	// createClient builds a new *http.Client; it is used at construction
	// time and again whenever httpClient is swapped out.
	createClient func() *http.Client
}
|
||
|
|
||
|
func NewOpenRouterClient(apiKey string) *OpenRouterClient {
|
||
|
logFile, err := os.OpenFile("debug.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to open debug.log: %v", err)
|
||
|
}
|
||
|
|
||
|
logger := log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags|log.Lshortfile)
|
||
|
logger.Println("Initializing OpenRouter client")
|
||
|
|
||
|
createClient := func() *http.Client {
|
||
|
transport := &http.Transport{
|
||
|
MaxIdleConns: 100,
|
||
|
MaxIdleConnsPerHost: 100,
|
||
|
IdleConnTimeout: 30 * time.Second,
|
||
|
DialContext: (&net.Dialer{
|
||
|
Timeout: connectionTimeout,
|
||
|
KeepAlive: 30 * time.Second,
|
||
|
}).DialContext,
|
||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||
|
}
|
||
|
|
||
|
return &http.Client{
|
||
|
Timeout: requestTimeout,
|
||
|
Transport: transport,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
client := &OpenRouterClient{
|
||
|
apiKey: apiKey,
|
||
|
httpClient: createClient(),
|
||
|
logger: logger,
|
||
|
createClient: createClient,
|
||
|
}
|
||
|
|
||
|
return client
|
||
|
}
|
||
|
|
||
|
type ChatCompletionRequest struct {
|
||
|
Model string `json:"model"`
|
||
|
Messages []ChatMessage `json:"messages"`
|
||
|
}
|
||
|
|
||
|
// ChatMessage is a single turn of a conversation: who is speaking (for
// example "system" or "user") and what they said.
type ChatMessage struct {
	// Role names the speaker of this turn.
	Role string `json:"role"`
	// Content is the text of this turn.
	Content string `json:"content"`
}
|
||
|
|
||
|
type ChatCompletionResponse struct {
|
||
|
Choices []struct {
|
||
|
Message ChatMessage `json:"message"`
|
||
|
} `json:"choices"`
|
||
|
}
|
||
|
|
||
|
func (c *OpenRouterClient) EvaluatePaper(ctx context.Context, paper models.Paper, criteria string, model string) (string, error) {
|
||
|
// Create a new context with evaluation timeout
|
||
|
evalCtx, cancel := context.WithTimeout(ctx, evaluationTimeout)
|
||
|
defer cancel()
|
||
|
|
||
|
startTime := time.Now()
|
||
|
c.logger.Printf("Starting evaluation for paper: %s\n", paper.Title)
|
||
|
c.logger.Printf("Evaluation timeout: %s\n", evaluationTimeout)
|
||
|
fmt.Printf("Starting evaluation for paper: %s at %s\n", paper.Title, startTime.Format(time.RFC3339))
|
||
|
|
||
|
prompt := fmt.Sprintf(`Evaluate this paper based on the following criteria:
|
||
|
%s
|
||
|
|
||
|
Paper Title: %s
|
||
|
Abstract: %s
|
||
|
|
||
|
Respond ONLY with a JSON object in this exact format:
|
||
|
{
|
||
|
"decision": "ACCEPT or REJECT",
|
||
|
"explanation": "Your explanation here"
|
||
|
}
|
||
|
|
||
|
Do not include any other information in your response.
|
||
|
|
||
|
IMPORTANT:
|
||
|
1. The decision MUST be either "ACCEPT" or "REJECT" (uppercase)
|
||
|
2. The explanation should be a clear, concise reason for your decision
|
||
|
3. Do not include any text outside the JSON object
|
||
|
4. Ensure the response is valid JSON (proper quotes and escaping)
|
||
|
5. Do not include any markdown or formatting`, criteria, paper.Title, paper.Abstract)
|
||
|
|
||
|
reqBody := ChatCompletionRequest{
|
||
|
Model: model,
|
||
|
Messages: []ChatMessage{
|
||
|
{
|
||
|
Role: "system",
|
||
|
Content: "You are a research paper evaluator. Respond only with the requested JSON format.",
|
||
|
},
|
||
|
{
|
||
|
Role: "user",
|
||
|
Content: prompt,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
var lastErr error
|
||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||
|
attemptStart := time.Now()
|
||
|
c.logger.Printf("Attempt %d started at %s\n", attempt+1, attemptStart.Format(time.RFC3339))
|
||
|
if attempt > 0 {
|
||
|
delay := time.Duration(math.Pow(2, float64(attempt))) * initialDelay
|
||
|
select {
|
||
|
case <-time.After(delay):
|
||
|
case <-ctx.Done():
|
||
|
return "", ctx.Err()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
jsonBody, err := json.Marshal(reqBody)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to marshal request body: %w", err)
|
||
|
}
|
||
|
|
||
|
req, err := http.NewRequestWithContext(evalCtx, "POST", openRouterURL, bytes.NewBuffer(jsonBody))
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to create request: %w", err)
|
||
|
}
|
||
|
|
||
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||
|
req.Header.Set("Content-Type", "application/json")
|
||
|
req.Header.Set("HTTP-Referer", "https://github.com/stwhite/arvix")
|
||
|
req.Header.Set("X-Title", "ArXiv Paper Processor")
|
||
|
|
||
|
resp, err := c.httpClient.Do(req)
|
||
|
if err != nil {
|
||
|
// Log the specific error type
|
||
|
c.logger.Printf("Attempt %d error: %v\n", attempt+1, err)
|
||
|
|
||
|
// Handle context cancellation/timeout
|
||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||
|
c.logger.Printf("Context deadline exceeded, retrying...\n")
|
||
|
lastErr = fmt.Errorf("context deadline exceeded")
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// On timeout errors, create a new client
|
||
|
if strings.Contains(err.Error(), "timeout") {
|
||
|
c.logger.Printf("Timeout detected, recreating HTTP client...\n")
|
||
|
c.httpClient = c.createClient()
|
||
|
}
|
||
|
|
||
|
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
|
||
|
continue
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
if resp.StatusCode != http.StatusOK {
|
||
|
body, _ := io.ReadAll(resp.Body)
|
||
|
lastErr = fmt.Errorf("attempt %d: openrouter request failed: %s - %s", attempt+1, resp.Status, string(body))
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
var completionResp ChatCompletionResponse
|
||
|
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
||
|
lastErr = fmt.Errorf("attempt %d: failed to decode response: %w", attempt+1, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if len(completionResp.Choices) == 0 {
|
||
|
lastErr = fmt.Errorf("attempt %d: no choices in response", attempt+1)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
rawContent := completionResp.Choices[0].Message.Content
|
||
|
|
||
|
// Try to parse as JSON first
|
||
|
var jsonResponse map[string]interface{}
|
||
|
err = json.Unmarshal([]byte(rawContent), &jsonResponse)
|
||
|
if err != nil {
|
||
|
// If direct JSON parsing fails, try extracting from markdown code block
|
||
|
startIdx := bytes.Index([]byte(rawContent), []byte("```json"))
|
||
|
if startIdx >= 0 {
|
||
|
startIdx += len("```json")
|
||
|
endIdx := bytes.Index([]byte(rawContent[startIdx:]), []byte("```"))
|
||
|
if endIdx >= 0 {
|
||
|
jsonContent := rawContent[startIdx : startIdx+endIdx]
|
||
|
err = json.Unmarshal([]byte(jsonContent), &jsonResponse)
|
||
|
if err != nil {
|
||
|
// If still failing, try to parse as raw JSON without code block
|
||
|
err = json.Unmarshal([]byte(rawContent), &jsonResponse)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err == nil {
|
||
|
// Validate and normalize decision
|
||
|
if decision, ok := jsonResponse["decision"].(string); ok {
|
||
|
// Normalize decision value
|
||
|
normalizedDecision := strings.ToUpper(strings.TrimSpace(decision))
|
||
|
if strings.Contains(normalizedDecision, "ACCEPT") {
|
||
|
normalizedDecision = "ACCEPT"
|
||
|
} else if strings.Contains(normalizedDecision, "REJECT") {
|
||
|
normalizedDecision = "REJECT"
|
||
|
}
|
||
|
|
||
|
if normalizedDecision == "ACCEPT" || normalizedDecision == "REJECT" {
|
||
|
// Preserve original decision in explanation
|
||
|
if explanation, ok := jsonResponse["explanation"]; !ok {
|
||
|
jsonResponse["explanation"] = fmt.Sprintf("Original decision: %s\n", decision)
|
||
|
} else {
|
||
|
jsonResponse["explanation"] = fmt.Sprintf("Original decision: %s\n%s", decision, explanation)
|
||
|
}
|
||
|
|
||
|
// Parse nested JSON in explanation if present
|
||
|
if explanation, ok := jsonResponse["explanation"].(string); ok {
|
||
|
var nested map[string]interface{}
|
||
|
if err := json.Unmarshal([]byte(explanation), &nested); err == nil {
|
||
|
jsonResponse["explanation"] = nested
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Ensure consistent response format
|
||
|
response := map[string]interface{}{
|
||
|
"paper": map[string]interface{}{
|
||
|
"title": paper.Title,
|
||
|
"abstract": paper.Abstract,
|
||
|
"arxiv_id": paper.ArxivID,
|
||
|
},
|
||
|
"decision": normalizedDecision,
|
||
|
"explanation": jsonResponse["explanation"],
|
||
|
}
|
||
|
|
||
|
responseJSON, err := json.Marshal(response)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to marshal response: %w", err)
|
||
|
}
|
||
|
|
||
|
duration := time.Since(startTime)
|
||
|
c.logger.Printf("Successfully evaluated paper: %s\n", paper.Title)
|
||
|
c.logger.Printf("Total time: %s\n", duration)
|
||
|
c.logger.Printf("Attempts: %d\n", attempt+1)
|
||
|
return string(responseJSON), nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If direct JSON parsing fails, try extracting from markdown code block
|
||
|
startIdx := bytes.Index([]byte(rawContent), []byte("```json"))
|
||
|
if startIdx >= 0 {
|
||
|
startIdx += len("```json")
|
||
|
endIdx := bytes.Index([]byte(rawContent[startIdx:]), []byte("```"))
|
||
|
if endIdx >= 0 {
|
||
|
jsonContent := rawContent[startIdx : startIdx+endIdx]
|
||
|
err = json.Unmarshal([]byte(jsonContent), &jsonResponse)
|
||
|
if err == nil {
|
||
|
if decision, ok := jsonResponse["decision"].(string); ok {
|
||
|
if decision == "ACCEPT" || decision == "REJECT" {
|
||
|
duration := time.Since(startTime)
|
||
|
c.logger.Printf("Successfully evaluated paper: %s\n", paper.Title)
|
||
|
c.logger.Printf("Total time: %s\n", duration)
|
||
|
c.logger.Printf("Attempts: %d\n", attempt+1)
|
||
|
return jsonContent, nil
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Fallback parsing if JSON is still invalid
|
||
|
decision := "ERROR"
|
||
|
if strings.Contains(rawContent, "ACCEPT") {
|
||
|
decision = "ACCEPT"
|
||
|
} else if strings.Contains(rawContent, "REJECT") {
|
||
|
decision = "REJECT"
|
||
|
}
|
||
|
|
||
|
// Create fallback response
|
||
|
fallbackResponse := map[string]interface{}{
|
||
|
"decision": decision,
|
||
|
"explanation": fmt.Sprintf("Original response: %s", rawContent),
|
||
|
}
|
||
|
fallbackJSON, _ := json.Marshal(fallbackResponse)
|
||
|
|
||
|
duration := time.Since(startTime)
|
||
|
c.logger.Printf("Fallback parsing used for paper: %s\n", paper.Title)
|
||
|
c.logger.Printf("Total time: %s\n", duration)
|
||
|
c.logger.Printf("Attempts: %d\n", attempt+1)
|
||
|
return string(fallbackJSON), nil
|
||
|
}
|
||
|
|
||
|
duration := time.Since(startTime)
|
||
|
c.logger.Printf("Failed to evaluate paper: %s\n", paper.Title)
|
||
|
c.logger.Printf("Total time: %s\n", duration)
|
||
|
c.logger.Printf("Attempts: %d\n", maxRetries)
|
||
|
return "", fmt.Errorf("max retries (%d) exceeded: %w", maxRetries, lastErr)
|
||
|
}
|