paper-system/llm_processor/processor/processor.go

191 lines
5.5 KiB
Go
Raw Normal View History

2025-01-24 15:26:47 +00:00
package processor
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"llm_processor/client"
"llm_processor/models"
"llm_processor/storage"
)
// Processor evaluates papers against stored criteria by sending them to a
// model through an OpenRouter client, one batch at a time.
type Processor struct {
// modelName is passed to the client with every evaluation request.
modelName string
// batchSize is the maximum number of papers placed in one batch.
batchSize int
// timeout is the duration used for the context deadline of a run.
timeout time.Duration
// client performs the paper evaluation calls.
client *client.OpenRouterClient
}
// SetTimeout replaces the processor's timeout. The new value is the duration
// used for the context deadline the next time papers are processed.
func (p *Processor) SetTimeout(timeout time.Duration) {
p.timeout = timeout
}
// NewProcessor returns a Processor that evaluates papers with the given model
// in batches of batchSize, using an OpenRouter client created from apiKey.
// The timeout starts at one hour.
func NewProcessor(modelName string, batchSize int, apiKey string) *Processor {
	// Default 1 hour timeout.
	const defaultTimeout = 3600 * time.Second

	proc := new(Processor)
	proc.modelName = modelName
	proc.batchSize = batchSize
	proc.timeout = defaultTimeout
	proc.client = client.NewOpenRouterClient(apiKey)
	return proc
}
// ProcessPapers evaluates every paper stored at inputPath and writes the
// results to outputPath.
//
// The run is bounded by the processor timeout, applied on top of parentCtx.
// Papers are loaded from inputPath, criteriaPath is registered with the
// storage package, the results file at outputPath is initialized, and the
// papers are then handed to processBatch in slices of at most batchSize
// entries. delay is forwarded unchanged to processBatch.
//
// It returns an error when the batch size is not positive, when loading the
// papers or initializing the results file fails, when the context ends before
// all batches have been started, or when a batch fails.
func (p *Processor) ProcessPapers(parentCtx context.Context, inputPath, outputPath, criteriaPath string, delay time.Duration) error {
	// A zero batch size would never advance the loop below and a negative one
	// would produce an invalid slice range, so reject both before doing any
	// work.
	if p.batchSize <= 0 {
		return fmt.Errorf("invalid batch size %d: must be greater than zero", p.batchSize)
	}

	startTime := time.Now()
	ctx, cancel := context.WithTimeout(parentCtx, p.timeout)
	defer cancel()

	// Load papers from input file
	papers, err := storage.LoadPapers(inputPath)
	if err != nil {
		return fmt.Errorf("failed to load papers: %w", err)
	}

	// Set criteria path in storage
	storage.SetCriteriaPath(criteriaPath)

	// Initialize results file
	if err := storage.InitializeResultsFile(outputPath); err != nil {
		return fmt.Errorf("failed to initialize results file: %w", err)
	}

	// Process papers in batches
	for i := 0; i < len(papers); i += p.batchSize {
		// Stop at a batch boundary once the run was cancelled or timed out
		// instead of starting work that can no longer succeed.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("failed to process batch: %w", err)
		}

		end := i + p.batchSize
		if end > len(papers) {
			end = len(papers)
		}
		batch := papers[i:end]
		if err := p.processBatch(ctx, batch, delay, startTime, outputPath); err != nil {
			return fmt.Errorf("failed to process batch: %w", err)
		}
	}
	return nil
}
// processBatch evaluates each paper in papers against the stored criteria and
// saves one result per paper to outputPath.
//
// Each paper is sent to the client up to three times, waiting 1s and then 2s
// between attempts. If every attempt fails while the context is still live,
// the paper is saved with decision "ERROR" and processing continues with the
// next paper. If the context is cancelled or its deadline passes, processBatch
// stops and returns an error instead of recording an ERROR result for every
// remaining paper.
//
// delay, when positive, is waited after each paper. startTime is only used to
// report the total runtime in the progress output; the progress counter is
// relative to this batch.
//
// It returns an error when the criteria cannot be loaded, when the context
// ends during evaluation, or when a result cannot be saved.
func (p *Processor) processBatch(ctx context.Context, papers []models.Paper, delay time.Duration, startTime time.Time, outputPath string) error {
	const maxAttempts = 3

	// Bound the batch by the processor timeout; an earlier deadline already
	// carried by ctx still takes precedence.
	ctx, cancel := context.WithTimeout(ctx, p.timeout)
	defer cancel()

	// Load criteria
	criteria, err := storage.LoadCriteria()
	if err != nil {
		return fmt.Errorf("failed to load criteria: %w", err)
	}

	for i, paper := range papers {
		// Once the context is done every request would fail immediately, so
		// stop here rather than writing an ERROR result per remaining paper.
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("failed to evaluate paper %q: %w", paper.Title, err)
		}

		var evaluation string
		var lastErr error
		for attempt := 1; attempt <= maxAttempts; attempt++ {
			evaluation, lastErr = p.client.EvaluatePaper(ctx, paper, criteria, p.modelName)
			if lastErr == nil || attempt == maxAttempts {
				// Success, or nothing left to retry: do not wait again.
				break
			}
			// Exponential backoff between attempts: 1s, then 2s.
			backoff := time.Second << uint(attempt-1)
			if err := sleepContext(ctx, backoff); err != nil {
				return fmt.Errorf("failed to evaluate paper %q: %w", paper.Title, err)
			}
		}

		var decision, explanation string
		if lastErr != nil {
			// A failure caused by the context ending aborts the batch; it says
			// nothing about the paper itself.
			if err := ctx.Err(); err != nil {
				return fmt.Errorf("failed to evaluate paper %q: %w", paper.Title, err)
			}
			// Record the failure directly. Formatting the error into a JSON
			// document and parsing it back breaks as soon as the message
			// contains quotes or newlines.
			decision = "ERROR"
			explanation = fmt.Sprintf("Failed to evaluate paper after %d attempts: %v", maxAttempts, lastErr)
		} else {
			decision, explanation = parseEvaluation(evaluation)
		}

		result := models.Result{
			Paper:       paper,
			Decision:    decision,
			Explanation: sanitizeExplanation(explanation),
		}

		// Save result with detailed logging
		fmt.Printf("Saving result for paper %q to %s\n", paper.Title, outputPath)
		if err := storage.SaveResult(result, outputPath); err != nil {
			fmt.Printf("Failed to save result for paper %q: %v\n", paper.Title, err)
			return fmt.Errorf("failed to save result: %w", err)
		}
		fmt.Printf("Successfully saved result for paper %q\n", paper.Title)

		// Print progress
		elapsed := time.Since(startTime).Seconds()
		fmt.Printf("Processed paper %d/%d (%s) - Total runtime: %.2f seconds\n",
			i+1, len(papers), paper.Title, elapsed)

		// Apply delay between papers if specified
		if delay > 0 {
			// An interrupted wait is not a failure on its own: the result is
			// already saved, and the check at the top of the loop reports the
			// cancellation if more papers remain.
			_ = sleepContext(ctx, delay)
		}
	}
	return nil
}

// parseEvaluation extracts the decision and explanation from a raw model
// reply. It tries, in order: the whole reply as JSON, the outermost {...} span
// of the reply as JSON (covers replies wrapped in prose or a Markdown code
// fence), and finally plain substring matching. When nothing matches, the
// decision is "ERROR" and the explanation is the raw reply.
func parseEvaluation(raw string) (decision, explanation string) {
	type reply struct {
		Decision    string `json:"decision"`
		Explanation string `json:"explanation"`
	}

	var whole reply
	if err := json.Unmarshal([]byte(raw), &whole); err == nil {
		return whole.Decision, whole.Explanation
	}

	if start, end := strings.Index(raw, "{"), strings.LastIndex(raw, "}"); start >= 0 && end > start {
		// Use a fresh value: a failed Unmarshal may leave fields half filled.
		var embedded reply
		if err := json.Unmarshal([]byte(raw[start:end+1]), &embedded); err == nil && embedded.Decision != "" {
			return embedded.Decision, embedded.Explanation
		}
	}

	// Last resort: look for the fields with substring matching.
	decision = "ERROR"
	switch {
	case strings.Contains(raw, `"decision": "ACCEPT"`), strings.Contains(raw, `"decision":"ACCEPT"`):
		decision = "ACCEPT"
	case strings.Contains(raw, `"decision": "REJECT"`), strings.Contains(raw, `"decision":"REJECT"`):
		decision = "REJECT"
	}

	explanation = raw
	if parts := strings.SplitN(raw, `"explanation"`, 3); len(parts) > 1 {
		if colon := strings.Index(parts[1], ":"); colon != -1 {
			value := strings.TrimLeft(parts[1][colon+1:], " \t\n\r\"")
			// Strip closing quotes, JSON punctuation, code fences and
			// whitespace in a single pass so that no closing quote is left
			// behind regardless of the order they appear in.
			explanation = strings.TrimRight(value, " \t\n\r\"},`")
		}
	}
	return decision, explanation
}

// sanitizeExplanation prepares an explanation for Markdown output: it removes
// code-fence markers, unwraps an explanation that is itself a JSON object with
// a non-empty "explanation" field, and backslash-escapes *, _ and `.
func sanitizeExplanation(explanation string) string {
	// Remove any markdown code block syntax
	explanation = strings.ReplaceAll(explanation, "```", "")

	// Unwrap the explanation if it is a raw JSON object
	if strings.HasPrefix(strings.TrimSpace(explanation), "{") {
		var nested struct {
			Explanation string `json:"explanation"`
		}
		if err := json.Unmarshal([]byte(explanation), &nested); err == nil && nested.Explanation != "" {
			explanation = nested.Explanation
		}
	}

	// Escape any remaining special markdown characters
	explanation = strings.ReplaceAll(explanation, "*", "\\*")
	explanation = strings.ReplaceAll(explanation, "_", "\\_")
	explanation = strings.ReplaceAll(explanation, "`", "\\`")
	return explanation
}

// sleepContext waits for d or until ctx is done, whichever comes first. It
// returns ctx.Err() when the context ended the wait and nil otherwise.
func sleepContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return ctx.Err()
	}
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}