package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
"arxiv-processor/arxiv"
|
|
"json2md/lib"
|
|
"llm_processor/processor"
|
|
)
|
|
|
|
func main() {
|
|
// Set custom usage before defining flags
|
|
flag.Usage = func() {
|
|
fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
|
|
fmt.Fprintf(os.Stderr, "A tool to fetch, filter, and process arXiv papers using LLM.\n\n")
|
|
fmt.Fprintf(os.Stderr, "Required flags:\n")
|
|
fmt.Fprintf(os.Stderr, " -criteria string\n\tPath to filter criteria file\n\n")
|
|
fmt.Fprintf(os.Stderr, "Source flags (must use either arXiv query OR input JSON):\n")
|
|
fmt.Fprintf(os.Stderr, " ArXiv query flags:\n")
|
|
fmt.Fprintf(os.Stderr, " -start string\n\tStart date in YYYYMMDD format\n")
|
|
fmt.Fprintf(os.Stderr, " -end string\n\tEnd date in YYYYMMDD format\n")
|
|
fmt.Fprintf(os.Stderr, " -search string\n\tarXiv category/search query (e.g., 'cat:cs.AI', 'au:kording')\n")
|
|
fmt.Fprintf(os.Stderr, " * see arXiv API docs for search types https://info.arxiv.org/help/api/user-manual.html#51-details-of-query-construction)\n")
|
|
fmt.Fprintf(os.Stderr, " -max-results int\n\tMaximum number of papers to retrieve (default: 100, max: 2000)\n\n")
|
|
fmt.Fprintf(os.Stderr, " OR\n\n")
|
|
fmt.Fprintf(os.Stderr, " Input JSON flag:\n")
|
|
fmt.Fprintf(os.Stderr, " -input-json string\n\tPath to input JSON file (bypasses arXiv fetch)\n\n")
|
|
fmt.Fprintf(os.Stderr, "Optional flags:\n")
|
|
fmt.Fprintf(os.Stderr, " -output string\n\tOutput markdown file path (default: auto-dated format when using arXiv query)\n")
|
|
fmt.Fprintf(os.Stderr, " -model string\n\tLLM model to use (default: nvidia/llama-3.1-nemotron-70b-instruct)\n\n")
|
|
fmt.Fprintf(os.Stderr, "Environment variables:\n")
|
|
fmt.Fprintf(os.Stderr, " OPENROUTER_API_KEY\tRequired for LLM processing\n\n")
|
|
fmt.Fprintf(os.Stderr, "Examples:\n")
|
|
fmt.Fprintf(os.Stderr, " Fetch from arXiv:\n")
|
|
fmt.Fprintf(os.Stderr, " %s -start 20240101 -end 20240131 -search cs.AI -criteria criteria.txt\n", os.Args[0])
|
|
fmt.Fprintf(os.Stderr, " Outputs: 20240101-20240131-cs.AI-papers.json/.md\n\n")
|
|
fmt.Fprintf(os.Stderr, " Use existing JSON:\n")
|
|
fmt.Fprintf(os.Stderr, " %s -input-json 20240101-20240131-cs.AI-papers.json -criteria new-criteria.txt -output custom-name.md\n", os.Args[0])
|
|
}
|
|
|
|
// CLI flags
|
|
start := flag.String("start", "", "Start date (YYYYMMDD)")
|
|
end := flag.String("end", "", "End date (YYYYMMDD)")
|
|
search := flag.String("search", "", "arXiv search query")
|
|
criteriaFile := flag.String("criteria", "", "Path to filter criteria file")
|
|
output := flag.String("output", "papers.md", "Output file path")
|
|
model := flag.String("model", "nvidia/llama-3.1-nemotron-70b-instruct", "LLM model to use")
|
|
maxResults := flag.Int("max-results", 100, "Maximum number of papers to retrieve (up to 2000)")
|
|
inputJSON := flag.String("input-json", "", "Path to input JSON file (bypasses arXiv fetch)")
|
|
|
|
flag.Parse()
|
|
|
|
// Validate flags
|
|
if *criteriaFile == "" {
|
|
fmt.Fprintf(os.Stderr, "Error: Missing required parameter: -criteria\n\n")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Validate either input-json is provided OR all arxiv flags are provided
|
|
usingInputJSON := *inputJSON != ""
|
|
usingArxiv := *start != "" || *end != "" || *search != ""
|
|
|
|
if usingInputJSON && usingArxiv {
|
|
fmt.Fprintf(os.Stderr, "Error: Cannot use both --input-json and arXiv query flags\n\n")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
if !usingInputJSON && !usingArxiv {
|
|
fmt.Fprintf(os.Stderr, "Error: Must provide either --input-json or arXiv query flags\n\n")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
if usingArxiv {
|
|
if *start == "" || *end == "" || *search == "" {
|
|
fmt.Fprintf(os.Stderr, "Error: Missing required arXiv parameters\n\n")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
if *maxResults <= 0 || *maxResults > 2000 {
|
|
fmt.Fprintf(os.Stderr, "Error: max-results must be between 1 and 2000\n\n")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
// Configure logging
|
|
log.SetPrefix("[paper-system] ")
|
|
log.SetFlags(log.Ltime | log.Lmsgprefix)
|
|
|
|
ctx := context.Background()
|
|
|
|
// Paper type used for JSON operations
|
|
type LLMPaper struct {
|
|
Title string `json:"title"`
|
|
Abstract string `json:"abstract"`
|
|
ArxivID string `json:"arxiv_id"`
|
|
Authors []string `json:"authors"`
|
|
}
|
|
|
|
var llmPapers []LLMPaper
|
|
|
|
if usingInputJSON {
|
|
// Load papers from input JSON
|
|
log.Printf("Loading papers from %s", *inputJSON)
|
|
paperData, err := os.ReadFile(*inputJSON)
|
|
if err != nil {
|
|
log.Fatalf("Failed to read input JSON: %v", err)
|
|
}
|
|
|
|
if err := json.Unmarshal(paperData, &llmPapers); err != nil {
|
|
log.Fatalf("Failed to parse input JSON: %v", err)
|
|
}
|
|
log.Printf("Loaded %d papers from JSON", len(llmPapers))
|
|
|
|
} else {
|
|
// Fetch papers from arXiv
|
|
log.Printf("Fetching papers from arXiv for category %q between %s and %s", *search, *start, *end)
|
|
arxivClient := arxiv.NewClient()
|
|
|
|
startDate := parseDate(*start)
|
|
endDate := parseDate(*end)
|
|
|
|
query := arxiv.Query{
|
|
Category: *search,
|
|
DateRange: fmt.Sprintf("%s TO %s", startDate.Format("20060102"), endDate.Format("20060102")),
|
|
MaxResults: *maxResults,
|
|
StartOffset: 0,
|
|
}
|
|
|
|
log.Printf("Executing arXiv query: %+v", query)
|
|
papers, err := arxivClient.FetchPapers(ctx, query)
|
|
if err != nil {
|
|
log.Fatalf("arXiv fetch failed: %v", err)
|
|
}
|
|
log.Printf("Retrieved %d papers from arXiv", len(papers))
|
|
if len(papers) >= *maxResults {
|
|
log.Printf("WARNING: Retrieved maximum number of papers (%d). There may be more papers available.", *maxResults)
|
|
log.Printf("Use --max-results flag to retrieve more papers (up to 2000)")
|
|
}
|
|
|
|
// Convert arXiv papers to LLM format
|
|
llmPapers = make([]LLMPaper, len(papers))
|
|
for i, p := range papers {
|
|
// Convert author structs to string array
|
|
authors := make([]string, len(p.Authors))
|
|
for j, a := range p.Authors {
|
|
authors[j] = a.Name
|
|
}
|
|
|
|
llmPapers[i] = LLMPaper{
|
|
Title: p.Title,
|
|
Abstract: p.Summary,
|
|
ArxivID: p.ID,
|
|
Authors: authors,
|
|
}
|
|
}
|
|
|
|
// Save papers to JSON for future use
|
|
log.Printf("Saving papers to papers.json")
|
|
papersJSON, err := json.Marshal(llmPapers)
|
|
if err != nil {
|
|
log.Fatalf("Failed to marshal papers: %v", err)
|
|
}
|
|
jsonName := fmt.Sprintf("%s-%s-%s-papers.json", *start, *end, *search)
|
|
if err := os.WriteFile(jsonName, papersJSON, 0644); err != nil {
|
|
log.Fatalf("Failed to save papers JSON: %v", err)
|
|
}
|
|
log.Printf("Successfully saved papers to %s", jsonName)
|
|
}
|
|
|
|
// Print paper titles for verification
|
|
log.Printf("Processing papers:")
|
|
for i, paper := range llmPapers {
|
|
log.Printf(" %d. %s", i+1, paper.Title)
|
|
}
|
|
|
|
// Save papers to temporary file for LLM processing
|
|
tempInput, err := os.CreateTemp("", "arxiv-process-*.tmp")
|
|
if err != nil {
|
|
log.Fatalf("Failed to create temp file: %v", err)
|
|
}
|
|
defer os.Remove(tempInput.Name())
|
|
|
|
tempJSON, err := json.Marshal(llmPapers)
|
|
if err != nil {
|
|
log.Fatalf("Failed to marshal papers for LLM: %v", err)
|
|
}
|
|
if err := os.WriteFile(tempInput.Name(), tempJSON, 0644); err != nil {
|
|
log.Fatalf("Failed to save temp input JSON: %v", err)
|
|
}
|
|
|
|
// Filter papers with LLM
|
|
log.Printf("Starting LLM processing")
|
|
apiKey := os.Getenv("OPENROUTER_API_KEY")
|
|
if apiKey == "" {
|
|
log.Fatal("OPENROUTER_API_KEY environment variable is required")
|
|
}
|
|
|
|
llmProcessor := processor.NewProcessor(*model, 32, apiKey) // 32 = batch size from README
|
|
log.Printf("Initialized LLM processor with model %s", *model)
|
|
|
|
tempOutput, err := os.CreateTemp("", "arxiv-process-*.tmp")
|
|
if err != nil {
|
|
log.Fatalf("Failed to create temp file: %v", err)
|
|
}
|
|
defer os.Remove(tempOutput.Name())
|
|
|
|
log.Printf("Processing papers with criteria from %s", *criteriaFile)
|
|
if err := llmProcessor.ProcessPapers(ctx, tempInput.Name(), tempOutput.Name(), *criteriaFile, 1*time.Second); err != nil {
|
|
log.Fatalf("LLM processing failed: %v", err)
|
|
}
|
|
log.Printf("LLM processing complete, results saved to %s", tempOutput.Name())
|
|
|
|
// Generate markdown
|
|
log.Printf("Generating markdown output")
|
|
decisions, err := lib.ProcessJSONFile(tempOutput.Name())
|
|
if err != nil {
|
|
log.Fatalf("Failed to process JSON: %v", err)
|
|
}
|
|
log.Printf("Processed decisions: %d accepted, %d rejected", len(decisions.Accepted), len(decisions.Rejected))
|
|
|
|
defaultOutput := fmt.Sprintf("%s-%s-%s-papers.md", *start, *end, *search)
|
|
if *output == "papers.md" && usingArxiv {
|
|
*output = defaultOutput
|
|
}
|
|
|
|
if err := lib.GenerateMarkdown(decisions, *output); err != nil {
|
|
log.Fatalf("Markdown generation failed: %v", err)
|
|
}
|
|
log.Printf("Generated markdown output at %s", *output)
|
|
|
|
// Temp files cleaned up by defer statements
|
|
log.Printf("Cleanup complete")
|
|
|
|
log.Printf("Process complete. Results saved to %s", *output)
|
|
}
|
|
|
|
func parseDate(s string) time.Time {
|
|
t, err := time.Parse("20060102", s)
|
|
if err != nil {
|
|
log.Fatalf("Invalid date %q: %v", s, err)
|
|
}
|
|
return t
|
|
}
|