paper-system/main.go

255 lines
8.7 KiB
Go

package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"time"
"arxiv-processor/arxiv"
"json2md/lib"
"llm_processor/processor"
)
// main is the CLI entry point. It parses and validates the command-line
// flags, obtains a list of papers (either from a previously saved JSON file
// or from a fresh arXiv query, which is also saved to a dated JSON file),
// and then hands the papers to filterAndRender for LLM filtering and
// markdown generation.
//
// Flag-validation failures print the usage text and exit with status 1;
// every later failure is reported through the log package and exits.
func main() {
	// Set custom usage before defining flags
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "A tool to fetch, filter, and process arXiv papers using LLM.\n\n")
		fmt.Fprintf(os.Stderr, "Required flags:\n")
		fmt.Fprintf(os.Stderr, " -criteria string\n\tPath to filter criteria file\n\n")
		fmt.Fprintf(os.Stderr, "Source flags (must use either arXiv query OR input JSON):\n")
		fmt.Fprintf(os.Stderr, " ArXiv query flags:\n")
		fmt.Fprintf(os.Stderr, " -start string\n\tStart date in YYYYMMDD format\n")
		fmt.Fprintf(os.Stderr, " -end string\n\tEnd date in YYYYMMDD format\n")
		fmt.Fprintf(os.Stderr, " -search string\n\tarXiv category/search query (e.g., 'cat:cs.AI', 'au:kording')\n")
		fmt.Fprintf(os.Stderr, " * see arXiv API docs for search types (https://info.arxiv.org/help/api/user-manual.html#51-details-of-query-construction)\n")
		fmt.Fprintf(os.Stderr, " -max-results int\n\tMaximum number of papers to retrieve (default: 100, max: 2000)\n\n")
		fmt.Fprintf(os.Stderr, " OR\n\n")
		fmt.Fprintf(os.Stderr, " Input JSON flag:\n")
		fmt.Fprintf(os.Stderr, " -input-json string\n\tPath to input JSON file (bypasses arXiv fetch)\n\n")
		fmt.Fprintf(os.Stderr, "Optional flags:\n")
		fmt.Fprintf(os.Stderr, " -output string\n\tOutput markdown file path (default: auto-dated format when using arXiv query)\n")
		fmt.Fprintf(os.Stderr, " -model string\n\tLLM model to use (default: nvidia/llama-3.1-nemotron-70b-instruct)\n\n")
		fmt.Fprintf(os.Stderr, "Environment variables:\n")
		fmt.Fprintf(os.Stderr, " OPENROUTER_API_KEY\tRequired for LLM processing\n\n")
		fmt.Fprintf(os.Stderr, "Examples:\n")
		fmt.Fprintf(os.Stderr, " Fetch from arXiv:\n")
		fmt.Fprintf(os.Stderr, " %s -start 20240101 -end 20240131 -search cs.AI -criteria criteria.txt\n", os.Args[0])
		fmt.Fprintf(os.Stderr, " Outputs: 20240101-20240131-cs.AI-papers.json/.md\n\n")
		fmt.Fprintf(os.Stderr, " Use existing JSON:\n")
		fmt.Fprintf(os.Stderr, " %s -input-json 20240101-20240131-cs.AI-papers.json -criteria new-criteria.txt -output custom-name.md\n", os.Args[0])
	}

	// CLI flags
	start := flag.String("start", "", "Start date (YYYYMMDD)")
	end := flag.String("end", "", "End date (YYYYMMDD)")
	search := flag.String("search", "", "arXiv search query")
	criteriaFile := flag.String("criteria", "", "Path to filter criteria file")
	output := flag.String("output", "papers.md", "Output file path")
	model := flag.String("model", "nvidia/llama-3.1-nemotron-70b-instruct", "LLM model to use")
	maxResults := flag.Int("max-results", 100, "Maximum number of papers to retrieve (up to 2000)")
	inputJSON := flag.String("input-json", "", "Path to input JSON file (bypasses arXiv fetch)")
	flag.Parse()

	// Remember whether -output was given explicitly. Comparing the value with
	// the default string cannot distinguish "not set" from a user who really
	// asked for papers.md, so ask the flag package which flags were set.
	outputSet := false
	flag.Visit(func(f *flag.Flag) {
		if f.Name == "output" {
			outputSet = true
		}
	})

	// usageError reports a flag-validation problem, prints the usage text and
	// exits with status 1.
	usageError := func(msg string) {
		fmt.Fprintf(os.Stderr, "Error: %s\n\n", msg)
		flag.Usage()
		os.Exit(1)
	}

	// Validate flags
	if *criteriaFile == "" {
		usageError("Missing required parameter: -criteria")
	}

	// Validate either input-json is provided OR all arxiv flags are provided
	usingInputJSON := *inputJSON != ""
	usingArxiv := *start != "" || *end != "" || *search != ""
	switch {
	case usingInputJSON && usingArxiv:
		usageError("Cannot use both --input-json and arXiv query flags")
	case !usingInputJSON && !usingArxiv:
		usageError("Must provide either --input-json or arXiv query flags")
	}
	if usingArxiv {
		if *start == "" || *end == "" || *search == "" {
			usageError("Missing required arXiv parameters")
		}
		if *maxResults <= 0 || *maxResults > 2000 {
			usageError("max-results must be between 1 and 2000")
		}
	}

	// Configure logging
	log.SetPrefix("[paper-system] ")
	log.SetFlags(log.Ltime | log.Lmsgprefix)
	ctx := context.Background()

	// Paper type used for JSON operations
	type LLMPaper struct {
		Title    string   `json:"title"`
		Abstract string   `json:"abstract"`
		ArxivID  string   `json:"arxiv_id"`
		Authors  []string `json:"authors"`
	}
	var llmPapers []LLMPaper

	if usingInputJSON {
		// Load papers from input JSON
		log.Printf("Loading papers from %s", *inputJSON)
		paperData, err := os.ReadFile(*inputJSON)
		if err != nil {
			log.Fatalf("Failed to read input JSON: %v", err)
		}
		if err := json.Unmarshal(paperData, &llmPapers); err != nil {
			log.Fatalf("Failed to parse input JSON: %v", err)
		}
		log.Printf("Loaded %d papers from JSON", len(llmPapers))
	} else {
		// Fetch papers from arXiv
		log.Printf("Fetching papers from arXiv for category %q between %s and %s", *search, *start, *end)
		arxivClient := arxiv.NewClient()
		startDate := parseDate(*start)
		endDate := parseDate(*end)
		// Reject an inverted range before spending a network round trip on it.
		if endDate.Before(startDate) {
			log.Fatalf("Invalid date range: end date %s is before start date %s", *end, *start)
		}
		query := arxiv.Query{
			Category:    *search,
			DateRange:   fmt.Sprintf("%s TO %s", startDate.Format("20060102"), endDate.Format("20060102")),
			MaxResults:  *maxResults,
			StartOffset: 0,
		}
		log.Printf("Executing arXiv query: %+v", query)
		papers, err := arxivClient.FetchPapers(ctx, query)
		if err != nil {
			log.Fatalf("arXiv fetch failed: %v", err)
		}
		log.Printf("Retrieved %d papers from arXiv", len(papers))
		if len(papers) >= *maxResults {
			log.Printf("WARNING: Retrieved maximum number of papers (%d). There may be more papers available.", *maxResults)
			log.Printf("Use --max-results flag to retrieve more papers (up to 2000)")
		}

		// Convert arXiv papers to LLM format
		llmPapers = make([]LLMPaper, len(papers))
		for i, p := range papers {
			// Convert author structs to string array
			authors := make([]string, len(p.Authors))
			for j, a := range p.Authors {
				authors[j] = a.Name
			}
			llmPapers[i] = LLMPaper{
				Title:    p.Title,
				Abstract: p.Summary,
				ArxivID:  p.ID,
				Authors:  authors,
			}
		}

		// Save papers to JSON for future use (can be fed back via -input-json)
		jsonName := fmt.Sprintf("%s-%s-%s-papers.json", *start, *end, *search)
		log.Printf("Saving papers to %s", jsonName)
		papersJSON, err := json.Marshal(llmPapers)
		if err != nil {
			log.Fatalf("Failed to marshal papers: %v", err)
		}
		if err := os.WriteFile(jsonName, papersJSON, 0644); err != nil {
			log.Fatalf("Failed to save papers JSON: %v", err)
		}
		log.Printf("Successfully saved papers to %s", jsonName)
	}

	// Print paper titles for verification
	log.Printf("Processing papers:")
	for i, paper := range llmPapers {
		log.Printf(" %d. %s", i+1, paper.Title)
	}

	// Serialize the papers for the LLM stage
	tempJSON, err := json.Marshal(llmPapers)
	if err != nil {
		log.Fatalf("Failed to marshal papers for LLM: %v", err)
	}

	// Filter papers with LLM. The API key is checked before any temp file is
	// created so that a missing key leaves nothing behind.
	log.Printf("Starting LLM processing")
	apiKey := os.Getenv("OPENROUTER_API_KEY")
	if apiKey == "" {
		log.Fatal("OPENROUTER_API_KEY environment variable is required")
	}

	// Without an explicit -output, an arXiv run writes to the dated file name
	// that matches the saved JSON.
	if !outputSet && usingArxiv {
		*output = fmt.Sprintf("%s-%s-%s-papers.md", *start, *end, *search)
	}

	if err := filterAndRender(ctx, tempJSON, *model, apiKey, *criteriaFile, *output); err != nil {
		// filterAndRender has returned, so its deferred temp-file removal has
		// already run by the time the process exits here.
		log.Fatal(err)
	}

	// Temp files were removed when filterAndRender returned
	log.Printf("Cleanup complete")
	log.Printf("Process complete. Results saved to %s", *output)
}

// filterAndRender runs the LLM filter over papersJSON and writes the resulting
// decisions as markdown to outputPath.
//
// papersJSON is written to a temporary input file, the processor writes its
// results to a second temporary file, and both are removed when this function
// returns, on success and on failure alike. Errors are returned instead of
// calling log.Fatal here, because log.Fatal exits without running deferred
// calls and would leave the temp files behind.
//
// The returned error text matches the fatal messages this program has always
// printed, so passing it to log.Fatal yields unchanged output.
func filterAndRender(ctx context.Context, papersJSON []byte, model, apiKey, criteriaFile, outputPath string) error {
	const batchSize = 32 // batch size from README

	// Save papers to temporary file for LLM processing
	tempInput, err := os.CreateTemp("", "arxiv-process-*.tmp")
	if err != nil {
		return fmt.Errorf("Failed to create temp file: %w", err)
	}
	defer os.Remove(tempInput.Name()) // best-effort cleanup; nothing useful to do on error

	// Write through the handle we already hold and close it right away: the
	// processor is given the path, not the handle.
	_, writeErr := tempInput.Write(papersJSON)
	if closeErr := tempInput.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		return fmt.Errorf("Failed to save temp input JSON: %w", writeErr)
	}

	llmProcessor := processor.NewProcessor(model, batchSize, apiKey)
	log.Printf("Initialized LLM processor with model %s", model)

	tempOutput, err := os.CreateTemp("", "arxiv-process-*.tmp")
	if err != nil {
		return fmt.Errorf("Failed to create temp file: %w", err)
	}
	defer os.Remove(tempOutput.Name()) // best-effort cleanup; nothing useful to do on error
	// Only the reserved path is needed; release the descriptor immediately.
	if err := tempOutput.Close(); err != nil {
		return fmt.Errorf("Failed to create temp file: %w", err)
	}

	log.Printf("Processing papers with criteria from %s", criteriaFile)
	if err := llmProcessor.ProcessPapers(ctx, tempInput.Name(), tempOutput.Name(), criteriaFile, 1*time.Second); err != nil {
		return fmt.Errorf("LLM processing failed: %w", err)
	}
	log.Printf("LLM processing complete, results saved to %s", tempOutput.Name())

	// Generate markdown
	log.Printf("Generating markdown output")
	decisions, err := lib.ProcessJSONFile(tempOutput.Name())
	if err != nil {
		return fmt.Errorf("Failed to process JSON: %w", err)
	}
	log.Printf("Processed decisions: %d accepted, %d rejected", len(decisions.Accepted), len(decisions.Rejected))

	if err := lib.GenerateMarkdown(decisions, outputPath); err != nil {
		return fmt.Errorf("Markdown generation failed: %w", err)
	}
	log.Printf("Generated markdown output at %s", outputPath)
	return nil
}
// parseDate converts a date written as YYYYMMDD into a time.Time.
//
// The layout carries no zone information, so time.Parse yields midnight UTC
// of the given day. If s does not match the layout, the error is reported
// through log.Fatalf, which terminates the program.
func parseDate(s string) time.Time {
	const layout = "20060102"

	parsed, err := time.Parse(layout, s)
	if err == nil {
		return parsed
	}

	log.Fatalf("Invalid date %q: %v", s, err)
	return parsed // not reached: log.Fatalf exits the process
}