// paperprocessor/paperprocessor.go (298 lines, 7.8 KiB, Go)

package paperprocessor
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
)
// Paper is one academic paper to be screened. Only Title and Abstract are
// shown to the model; ArxivID is carried through to the results unchanged.
type Paper struct {
	Title    string `json:"title"`
	Abstract string `json:"abstract"`
	ArxivID  string `json:"arxiv_id"`
}
// PaperResult pairs a paper with the model's verdict on it. As produced by
// ProcessPapers, Decision is the normalized keyword "ACCEPT" or "REJECT" and
// Explanation is the model's free-text justification.
type PaperResult struct {
	Paper       Paper  `json:"paper"`
	Decision    string `json:"decision"`
	Explanation string `json:"explanation"`
}
// ProcessingResult is the complete outcome of a ProcessPapers run; each input
// paper ends up in exactly one of the three lists.
type ProcessingResult struct {
	// Accepted holds papers the model judged to meet the criteria.
	Accepted []PaperResult `json:"accepted"`
	// Rejected holds papers the model did not accept.
	Rejected []PaperResult `json:"rejected"`
	// Failed holds papers whose evaluation returned an error, with the error
	// text. Output is reserved for the raw model reply; ProcessPapers
	// currently leaves it empty.
	Failed []struct {
		Paper  Paper  `json:"paper"`
		Error  string `json:"error"`
		Output string `json:"output"`
	} `json:"failed"`
}
// Config configures a Processor. APIEndpoint, APIKey and Model are required;
// NewProcessor rejects a Config in which any of them is empty.
type Config struct {
	// APIEndpoint is the URL every evaluation request is POSTed to as JSON.
	APIEndpoint string
	// APIKey is sent as a bearer token in the Authorization header.
	APIKey string
	// Model is the model name placed in every request body.
	Model string
	// RequestDelay is the pause between consecutive API requests. Zero
	// selects NewProcessor's default of one second.
	RequestDelay time.Duration
}
// Processor screens papers against free-text criteria by querying the LLM
// described by its Config. Create one with NewProcessor, which validates the
// configuration and fills in defaults.
type Processor struct {
	config Config // as validated and defaulted by NewProcessor
}
// NewProcessor validates config and returns a Processor that uses it.
//
// APIKey, APIEndpoint and Model must be non-empty (checked in that order;
// the first missing one is reported). A zero RequestDelay is replaced with
// one second. config is taken by value, so the caller's copy is not modified.
func NewProcessor(config Config) (*Processor, error) {
	switch {
	case config.APIKey == "":
		return nil, fmt.Errorf("API key is required")
	case config.APIEndpoint == "":
		return nil, fmt.Errorf("API endpoint is required")
	case config.Model == "":
		return nil, fmt.Errorf("model name is required")
	}

	if config.RequestDelay == 0 {
		config.RequestDelay = time.Second
	}
	return &Processor{config: config}, nil
}
// ProcessPapers evaluates each paper against criteria with one LLM request
// per paper and sorts the outcomes into accepted, rejected and failed lists.
//
// Requests are made sequentially, sleeping p.config.RequestDelay between
// them (not before the first). A paper whose evaluation fails does not abort
// the run; it is recorded in Failed with the error text. If any paper failed,
// the Failed list is also written, best effort, to "dump.json" in the current
// working directory.
//
// All three result slices are non-nil, so they encode as [] rather than null
// in JSON. The returned error is currently always nil.
func (p *Processor) ProcessPapers(papers []Paper, criteria string) (*ProcessingResult, error) {
	// failedPaper names ProcessingResult's anonymous Failed element type so it
	// is not spelled out twice; being an alias, it is the identical type.
	type failedPaper = struct {
		Paper  Paper  `json:"paper"`
		Error  string `json:"error"`
		Output string `json:"output"`
	}

	result := &ProcessingResult{
		Accepted: make([]PaperResult, 0),
		Rejected: make([]PaperResult, 0),
		// Initialised like the other lists so "failed" encodes as [] too.
		Failed: make([]failedPaper, 0),
	}

	for i, paper := range papers {
		// Throttle between requests; nothing to wait for before the first.
		if i > 0 {
			time.Sleep(p.config.RequestDelay)
		}

		decision, err := p.evaluatePaper(paper, criteria)
		if err != nil {
			// One bad paper must not abort the batch: record it and move on.
			// Output stays empty; for parse failures evaluatePaper already
			// embeds the full model reply in the error text.
			result.Failed = append(result.Failed, failedPaper{
				Paper: paper,
				Error: err.Error(),
			})
			continue
		}

		paperResult := PaperResult{
			Paper:       paper,
			Decision:    decision.Decision,
			Explanation: decision.Explanation,
		}
		if decision.Decision == "ACCEPT" {
			result.Accepted = append(result.Accepted, paperResult)
		} else {
			result.Rejected = append(result.Rejected, paperResult)
		}
	}

	if len(result.Failed) > 0 {
		if dumpData, err := json.MarshalIndent(result.Failed, "", " "); err == nil {
			// Best effort debugging aid: the same data is already in the
			// returned result, so a failed dump is deliberately not an error.
			_ = ioutil.WriteFile("dump.json", dumpData, 0644)
		}
	}
	return result, nil
}
// Wire types for the LLM API, plus the parsed verdict handed back to
// ProcessPapers.
type (
	// llmRequest is the JSON body POSTed to the API endpoint: a model name
	// and the list of chat messages.
	llmRequest struct {
		Model    string    `json:"model"`
		Messages []message `json:"messages"`
	}

	// message is one chat turn: the speaker's role and the text.
	message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// llmResponse decodes only the part of the API reply that is read: the
	// text content of each returned choice (evaluatePaper uses the first).
	llmResponse struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}

	// decisionResult is evaluatePaper's parsed verdict for one paper.
	decisionResult struct {
		Decision    string
		Explanation string
	}
)
// llmHTTPTimeout bounds one complete LLM API exchange (dial, request and
// reading the response body). Generation can be slow, so the limit is
// generous, but it keeps a stalled endpoint from blocking ProcessPapers
// forever.
const llmHTTPTimeout = 2 * time.Minute

// llmHTTPClient is shared by all evaluations so connections to the API
// endpoint can be reused between papers instead of being re-dialed.
var llmHTTPClient = &http.Client{Timeout: llmHTTPTimeout}

// decisionDecoration lists characters models often wrap the decision keyword
// in despite being told not to: list bullets, markdown emphasis and heading
// markers, quotes, and surrounding whitespace.
const decisionDecoration = "-*#_\"'` \t"

// evaluatePaper asks the configured LLM whether paper satisfies criteria and
// returns the normalized decision ("ACCEPT" or "REJECT") with the model's
// explanation.
//
// It fails if any of the following happens:
//   - the request cannot be built or sent;
//   - the endpoint answers with a non-2xx status;
//   - the reply cannot be decoded or has no choices;
//   - no recognizable decision appears in the model's text.
//
// Parse errors include the full model reply to aid debugging.
func (p *Processor) evaluatePaper(paper Paper, criteria string) (*decisionResult, error) {
	content, err := p.queryLLM(buildEvaluationPrompt(paper, criteria))
	if err != nil {
		return nil, err
	}
	return parseDecision(content)
}

// buildEvaluationPrompt renders the single-user-message prompt for one paper.
func buildEvaluationPrompt(paper Paper, criteria string) string {
	return fmt.Sprintf(`Please evaluate the following academic paper against the provided criteria.
Respond with either "ACCEPT" or "REJECT" followed by an explanation of your decision.
For ACCEPT decisions, provide a thorough explanation. For REJECT decisions, keep the explanation brief and focused on the key reason.
Do not use markdown, bullet points, or quotes in your response. Keep your response clear and concise.
Your response should be in the format:
DECISION
Explanation
Criteria:
%s
Paper Title: %s
Abstract: %s`, criteria, paper.Title, paper.Abstract)
}

// queryLLM POSTs prompt as a single "user" message to p.config.APIEndpoint
// and returns the text of the first choice in the reply. The request and
// response shapes follow llmRequest/llmResponse (presumably an
// OpenAI-compatible chat completions API — confirm against the provider).
func (p *Processor) queryLLM(prompt string) (string, error) {
	reqJSON, err := json.Marshal(llmRequest{
		Model:    p.config.Model,
		Messages: []message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return "", fmt.Errorf("error marshaling request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, p.config.APIEndpoint, bytes.NewReader(reqJSON))
	if err != nil {
		return "", fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)

	resp, err := llmHTTPClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("error making request: %w", err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response: %w", err)
	}
	// Without this check an error payload (rate limit, bad key, ...) would
	// decode to zero choices and be misreported as "no response from LLM".
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("LLM API returned %s: %s", resp.Status, strings.TrimSpace(string(body)))
	}

	var llmResp llmResponse
	if err := json.Unmarshal(body, &llmResp); err != nil {
		return "", fmt.Errorf("error unmarshaling response: %w", err)
	}
	if len(llmResp.Choices) == 0 {
		return "", fmt.Errorf("no response from LLM")
	}
	return llmResp.Choices[0].Message.Content, nil
}

// parseDecision extracts the verdict from the model's free-text reply.
//
// The decision comes from the first line that mentions ACCEPT or REJECT
// (case-insensitively). Leading and trailing decoration and an optional
// "Decision" label are stripped from that line. What remains must then start
// with one of the two keywords.
//
// The explanation is the text after the decision line. If there is none, as
// in a one-line reply like "REJECT - off topic", the rest of the decision
// line after the keyword is used instead.
func parseDecision(content string) (*decisionResult, error) {
	lines := strings.Split(content, "\n")
	idx := -1
	for i, line := range lines {
		upper := strings.ToUpper(line)
		if strings.Contains(upper, "ACCEPT") || strings.Contains(upper, "REJECT") {
			idx = i
			break
		}
	}
	if idx < 0 {
		return nil, fmt.Errorf("no decision found in response. Full response:\n%s", content)
	}

	rawDecision := strings.TrimSpace(lines[idx])
	// Handles e.g. "**ACCEPT**", "- REJECT", "Decision: accept",
	// "**Decision:** REJECT".
	cleanDecision := strings.Trim(rawDecision, decisionDecoration)
	const label = "DECISION"
	if hasPrefixFold(cleanDecision, label) {
		cleanDecision = strings.Trim(cleanDecision[len(label):], decisionDecoration+":")
	}

	var decision string
	switch {
	case hasPrefixFold(cleanDecision, "ACCEPT"):
		decision = "ACCEPT"
	case hasPrefixFold(cleanDecision, "REJECT"):
		decision = "REJECT"
	default:
		return nil, fmt.Errorf("invalid decision value: %q (cleaned: %q). Full response:\n%s",
			rawDecision, cleanDecision, content)
	}

	explanation := strings.TrimSpace(strings.Join(lines[idx+1:], "\n"))
	if explanation == "" {
		rest := cleanDecision[len(decision):]
		if hasPrefixFold(rest, "ED") { // "ACCEPTED" / "REJECTED"
			rest = rest[len("ED"):]
		}
		explanation = strings.Trim(rest, decisionDecoration+":.,;")
	}

	return &decisionResult{
		Decision:    decision,
		Explanation: explanation,
	}, nil
}

// hasPrefixFold reports whether s begins with prefix, ignoring case.
func hasPrefixFold(s, prefix string) bool {
	return len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix)
}
// ProcessFile screens the papers in a JSON file and writes the results to
// another JSON file.
//
// It reads a JSON array of Paper objects from inputPath and the screening
// criteria from criteriaPath; the criteria are used verbatim as prompt text.
// The papers are evaluated by a Processor built from config. The resulting
// ProcessingResult is written as indented JSON to outputPath with mode 0644.
// ProcessPapers may additionally write "dump.json" when some papers fail.
//
// config is validated before any file is read, so a misconfiguration is
// reported without touching the filesystem. Errors are wrapped with %w, so
// callers can inspect the cause, e.g. errors.Is(err, fs.ErrNotExist).
func ProcessFile(inputPath, outputPath, criteriaPath string, config Config) error {
	processor, err := NewProcessor(config)
	if err != nil {
		return fmt.Errorf("error creating processor: %w", err)
	}

	inputData, err := ioutil.ReadFile(inputPath)
	if err != nil {
		return fmt.Errorf("error reading input file: %w", err)
	}
	var papers []Paper
	if err := json.Unmarshal(inputData, &papers); err != nil {
		return fmt.Errorf("error parsing input JSON: %w", err)
	}

	criteriaData, err := ioutil.ReadFile(criteriaPath)
	if err != nil {
		return fmt.Errorf("error reading criteria file: %w", err)
	}

	result, err := processor.ProcessPapers(papers, string(criteriaData))
	if err != nil {
		return fmt.Errorf("error processing papers: %w", err)
	}

	outputData, err := json.MarshalIndent(result, "", " ")
	if err != nil {
		return fmt.Errorf("error marshaling output JSON: %w", err)
	}
	if err := ioutil.WriteFile(outputPath, outputData, 0644); err != nil {
		return fmt.Errorf("error writing output file: %w", err)
	}
	return nil
}