191 lines
5.5 KiB
Go
191 lines
5.5 KiB
Go
package processor
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"llm_processor/client"
|
|
"llm_processor/models"
|
|
"llm_processor/storage"
|
|
)
|
|
|
|
type Processor struct {
|
|
modelName string
|
|
batchSize int
|
|
timeout time.Duration
|
|
client *client.OpenRouterClient
|
|
}
|
|
|
|
// SetTimeout allows changing the processor timeout
|
|
func (p *Processor) SetTimeout(timeout time.Duration) {
|
|
p.timeout = timeout
|
|
}
|
|
|
|
func NewProcessor(modelName string, batchSize int, apiKey string) *Processor {
|
|
return &Processor{
|
|
modelName: modelName,
|
|
batchSize: batchSize,
|
|
timeout: 3600 * time.Second, // Default 1 hour timeout
|
|
client: client.NewOpenRouterClient(apiKey),
|
|
}
|
|
}
|
|
|
|
func (p *Processor) ProcessPapers(parentCtx context.Context, inputPath, outputPath, criteriaPath string, delay time.Duration) error {
|
|
startTime := time.Now()
|
|
ctx, cancel := context.WithTimeout(parentCtx, p.timeout)
|
|
defer cancel()
|
|
|
|
// Load papers from input file
|
|
papers, err := storage.LoadPapers(inputPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load papers: %w", err)
|
|
}
|
|
|
|
// Set criteria path in storage
|
|
storage.SetCriteriaPath(criteriaPath)
|
|
|
|
// Initialize results file
|
|
if err := storage.InitializeResultsFile(outputPath); err != nil {
|
|
return fmt.Errorf("failed to initialize results file: %w", err)
|
|
}
|
|
|
|
// Process papers in batches
|
|
for i := 0; i < len(papers); i += p.batchSize {
|
|
end := i + p.batchSize
|
|
if end > len(papers) {
|
|
end = len(papers)
|
|
}
|
|
|
|
batch := papers[i:end]
|
|
if err := p.processBatch(ctx, batch, delay, startTime, outputPath); err != nil {
|
|
return fmt.Errorf("failed to process batch: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Processor) processBatch(ctx context.Context, papers []models.Paper, delay time.Duration, startTime time.Time, outputPath string) error {
|
|
ctx, cancel := context.WithTimeout(ctx, p.timeout)
|
|
defer cancel()
|
|
|
|
// Load criteria
|
|
criteria, err := storage.LoadCriteria()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load criteria: %w", err)
|
|
}
|
|
for i, paper := range papers {
|
|
var evaluation string
|
|
var lastErr error
|
|
|
|
// Retry up to 3 times with exponential backoff
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
evaluation, lastErr = p.client.EvaluatePaper(ctx, paper, criteria, p.modelName)
|
|
if lastErr == nil {
|
|
break
|
|
}
|
|
time.Sleep(time.Duration(attempt+1) * time.Second) // Exponential backoff
|
|
}
|
|
|
|
if lastErr != nil {
|
|
// Log error but continue with next paper
|
|
evaluation = fmt.Sprintf(`{
|
|
"decision": "ERROR",
|
|
"explanation": "Failed to evaluate paper after 3 attempts: %v"
|
|
}`, lastErr)
|
|
}
|
|
|
|
// Parse and validate evaluation response
|
|
var evalResponse struct {
|
|
Decision string `json:"decision"`
|
|
Explanation string `json:"explanation"`
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(evaluation), &evalResponse); err != nil {
|
|
// Try to extract decision and explanation using regex
|
|
decision := "ERROR"
|
|
explanation := evaluation
|
|
|
|
if strings.Contains(evaluation, `"decision": "ACCEPT"`) || strings.Contains(evaluation, `"decision":"ACCEPT"`) {
|
|
decision = "ACCEPT"
|
|
} else if strings.Contains(evaluation, `"decision": "REJECT"`) || strings.Contains(evaluation, `"decision":"REJECT"`) {
|
|
decision = "REJECT"
|
|
}
|
|
|
|
// Try to extract just the explanation if it's in JSON format
|
|
if strings.Contains(evaluation, `"explanation"`) {
|
|
parts := strings.Split(evaluation, `"explanation"`)
|
|
if len(parts) > 1 {
|
|
// Find the content between the first : and the next "
|
|
expl := parts[1]
|
|
start := strings.Index(expl, ":")
|
|
if start != -1 {
|
|
expl = expl[start+1:]
|
|
// Remove leading/trailing whitespace and quotes
|
|
expl = strings.Trim(expl, " \t\n\r\"")
|
|
// Remove trailing JSON syntax
|
|
expl = strings.TrimRight(expl, "}")
|
|
expl = strings.TrimRight(expl, ",")
|
|
explanation = expl
|
|
}
|
|
}
|
|
}
|
|
|
|
evalResponse = struct {
|
|
Decision string `json:"decision"`
|
|
Explanation string `json:"explanation"`
|
|
}{
|
|
Decision: decision,
|
|
Explanation: explanation,
|
|
}
|
|
}
|
|
|
|
// Sanitize the explanation
|
|
explanation := evalResponse.Explanation
|
|
// Remove any markdown code block syntax
|
|
explanation = strings.ReplaceAll(explanation, "```", "")
|
|
// Remove any JSON formatting if the explanation is a raw JSON string
|
|
if strings.HasPrefix(strings.TrimSpace(explanation), "{") {
|
|
var jsonExpl struct {
|
|
Explanation string `json:"explanation"`
|
|
}
|
|
if err := json.Unmarshal([]byte(explanation), &jsonExpl); err == nil && jsonExpl.Explanation != "" {
|
|
explanation = jsonExpl.Explanation
|
|
}
|
|
}
|
|
// Escape any remaining special markdown characters
|
|
explanation = strings.ReplaceAll(explanation, "*", "\\*")
|
|
explanation = strings.ReplaceAll(explanation, "_", "\\_")
|
|
explanation = strings.ReplaceAll(explanation, "`", "\\`")
|
|
|
|
result := models.Result{
|
|
Paper: paper,
|
|
Decision: evalResponse.Decision,
|
|
Explanation: explanation,
|
|
}
|
|
|
|
// Save result with detailed logging
|
|
fmt.Printf("Saving result for paper %q to %s\n", paper.Title, outputPath)
|
|
if err := storage.SaveResult(result, outputPath); err != nil {
|
|
fmt.Printf("Failed to save result for paper %q: %v\n", paper.Title, err)
|
|
return fmt.Errorf("failed to save result: %w", err)
|
|
}
|
|
fmt.Printf("Successfully saved result for paper %q\n", paper.Title)
|
|
|
|
// Print progress
|
|
elapsed := time.Since(startTime).Seconds()
|
|
fmt.Printf("Processed paper %d/%d (%s) - Total runtime: %.2f seconds\n",
|
|
i+1, len(papers), paper.Title, elapsed)
|
|
|
|
// Apply delay between papers if specified
|
|
if delay > 0 {
|
|
time.Sleep(delay)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|