package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"time"

	"arxiv-processor/arxiv"
	"json2md/lib"
	"llm_processor/processor"
)
func main() {
|
||
|
// Set custom usage before defining flags
|
||
|
flag.Usage = func() {
|
||
|
fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
|
||
|
fmt.Fprintf(os.Stderr, "A tool to fetch, filter, and process arXiv papers using LLM.\n\n")
|
||
|
fmt.Fprintf(os.Stderr, "Required flags:\n")
|
||
|
fmt.Fprintf(os.Stderr, " -criteria string\n\tPath to filter criteria file\n\n")
|
||
|
fmt.Fprintf(os.Stderr, "Source flags (must use either arXiv query OR input JSON):\n")
|
||
|
fmt.Fprintf(os.Stderr, " ArXiv query flags:\n")
|
||
|
fmt.Fprintf(os.Stderr, " -start string\n\tStart date in YYYYMMDD format\n")
|
||
|
fmt.Fprintf(os.Stderr, " -end string\n\tEnd date in YYYYMMDD format\n")
|
||
|
fmt.Fprintf(os.Stderr, " -search string\n\tarXiv category/search query (e.g., 'cs.AI', 'physics.comp-ph')\n")
|
||
|
fmt.Fprintf(os.Stderr, " -max-results int\n\tMaximum number of papers to retrieve (default: 100, max: 2000)\n\n")
|
||
|
fmt.Fprintf(os.Stderr, " OR\n\n")
|
||
|
fmt.Fprintf(os.Stderr, " Input JSON flag:\n")
|
||
|
fmt.Fprintf(os.Stderr, " -input-json string\n\tPath to input JSON file (bypasses arXiv fetch)\n\n")
|
||
|
fmt.Fprintf(os.Stderr, "Optional flags:\n")
|
||
|
fmt.Fprintf(os.Stderr, " -output string\n\tOutput markdown file path (default: papers.md)\n")
|
||
|
fmt.Fprintf(os.Stderr, " -model string\n\tLLM model to use (default: nvidia/llama-3.1-nemotron-70b-instruct)\n\n")
|
||
|
fmt.Fprintf(os.Stderr, "Environment variables:\n")
|
||
|
fmt.Fprintf(os.Stderr, " OPENROUTER_API_KEY\tRequired for LLM processing\n\n")
|
||
|
fmt.Fprintf(os.Stderr, "Examples:\n")
|
||
|
fmt.Fprintf(os.Stderr, " Fetch from arXiv:\n")
|
||
|
fmt.Fprintf(os.Stderr, " %s -start 20240101 -end 20240131 -search cs.AI -criteria criteria.txt -output papers.md\n\n", os.Args[0])
|
||
|
fmt.Fprintf(os.Stderr, " Use existing JSON:\n")
|
||
|
fmt.Fprintf(os.Stderr, " %s -input-json papers.json -criteria new-criteria.txt -output results.md\n", os.Args[0])
|
||
|
}
|
||
|
|
||
|
// CLI flags
|
||
|
start := flag.String("start", "", "Start date (YYYYMMDD)")
|
||
|
end := flag.String("end", "", "End date (YYYYMMDD)")
|
||
|
search := flag.String("search", "", "arXiv search query")
|
||
|
criteriaFile := flag.String("criteria", "", "Path to filter criteria file")
|
||
|
output := flag.String("output", "papers.md", "Output file path")
|
||
|
model := flag.String("model", "nvidia/llama-3.1-nemotron-70b-instruct", "LLM model to use")
|
||
|
maxResults := flag.Int("max-results", 100, "Maximum number of papers to retrieve (up to 2000)")
|
||
|
inputJSON := flag.String("input-json", "", "Path to input JSON file (bypasses arXiv fetch)")
|
||
|
|
||
|
flag.Parse()
|
||
|
|
||
|
// Validate flags
|
||
|
if *criteriaFile == "" {
|
||
|
fmt.Fprintf(os.Stderr, "Error: Missing required parameter: -criteria\n\n")
|
||
|
flag.Usage()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
|
||
|
// Validate either input-json is provided OR all arxiv flags are provided
|
||
|
usingInputJSON := *inputJSON != ""
|
||
|
usingArxiv := *start != "" || *end != "" || *search != ""
|
||
|
|
||
|
if usingInputJSON && usingArxiv {
|
||
|
fmt.Fprintf(os.Stderr, "Error: Cannot use both --input-json and arXiv query flags\n\n")
|
||
|
flag.Usage()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
|
||
|
if !usingInputJSON && !usingArxiv {
|
||
|
fmt.Fprintf(os.Stderr, "Error: Must provide either --input-json or arXiv query flags\n\n")
|
||
|
flag.Usage()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
|
||
|
if usingArxiv {
|
||
|
if *start == "" || *end == "" || *search == "" {
|
||
|
fmt.Fprintf(os.Stderr, "Error: Missing required arXiv parameters\n\n")
|
||
|
flag.Usage()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
|
||
|
if *maxResults <= 0 || *maxResults > 2000 {
|
||
|
fmt.Fprintf(os.Stderr, "Error: max-results must be between 1 and 2000\n\n")
|
||
|
flag.Usage()
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Configure logging
|
||
|
log.SetPrefix("[paper-system] ")
|
||
|
log.SetFlags(log.Ltime | log.Lmsgprefix)
|
||
|
|
||
|
ctx := context.Background()
|
||
|
|
||
|
// Paper type used for JSON operations
|
||
|
type LLMPaper struct {
|
||
|
Title string `json:"title"`
|
||
|
Abstract string `json:"abstract"`
|
||
|
ArxivID string `json:"arxiv_id"`
|
||
|
Authors []string `json:"authors"`
|
||
|
}
|
||
|
|
||
|
var llmPapers []LLMPaper
|
||
|
|
||
|
if usingInputJSON {
|
||
|
// Load papers from input JSON
|
||
|
log.Printf("Loading papers from %s", *inputJSON)
|
||
|
paperData, err := os.ReadFile(*inputJSON)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to read input JSON: %v", err)
|
||
|
}
|
||
|
|
||
|
if err := json.Unmarshal(paperData, &llmPapers); err != nil {
|
||
|
log.Fatalf("Failed to parse input JSON: %v", err)
|
||
|
}
|
||
|
log.Printf("Loaded %d papers from JSON", len(llmPapers))
|
||
|
|
||
|
} else {
|
||
|
// Fetch papers from arXiv
|
||
|
log.Printf("Fetching papers from arXiv for category %q between %s and %s", *search, *start, *end)
|
||
|
arxivClient := arxiv.NewClient()
|
||
|
|
||
|
startDate := parseDate(*start)
|
||
|
endDate := parseDate(*end)
|
||
|
|
||
|
query := arxiv.Query{
|
||
|
Category: *search,
|
||
|
DateRange: fmt.Sprintf("%s TO %s", startDate.Format("20060102"), endDate.Format("20060102")),
|
||
|
MaxResults: *maxResults,
|
||
|
StartOffset: 0,
|
||
|
}
|
||
|
|
||
|
log.Printf("Executing arXiv query: %+v", query)
|
||
|
papers, err := arxivClient.FetchPapers(ctx, query)
|
||
|
if err != nil {
|
||
|
log.Fatalf("arXiv fetch failed: %v", err)
|
||
|
}
|
||
|
log.Printf("Retrieved %d papers from arXiv", len(papers))
|
||
|
if len(papers) >= *maxResults {
|
||
|
log.Printf("WARNING: Retrieved maximum number of papers (%d). There may be more papers available.", *maxResults)
|
||
|
log.Printf("Use --max-results flag to retrieve more papers (up to 2000)")
|
||
|
}
|
||
|
|
||
|
// Convert arXiv papers to LLM format
|
||
|
llmPapers = make([]LLMPaper, len(papers))
|
||
|
for i, p := range papers {
|
||
|
// Convert author structs to string array
|
||
|
authors := make([]string, len(p.Authors))
|
||
|
for j, a := range p.Authors {
|
||
|
authors[j] = a.Name
|
||
|
}
|
||
|
|
||
|
llmPapers[i] = LLMPaper{
|
||
|
Title: p.Title,
|
||
|
Abstract: p.Summary,
|
||
|
ArxivID: p.ID,
|
||
|
Authors: authors,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Save papers to JSON for future use
|
||
|
log.Printf("Saving papers to papers.json")
|
||
|
papersJSON, err := json.Marshal(llmPapers)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to marshal papers: %v", err)
|
||
|
}
|
||
|
if err := os.WriteFile("papers.json", papersJSON, 0644); err != nil {
|
||
|
log.Fatalf("Failed to save papers JSON: %v", err)
|
||
|
}
|
||
|
log.Printf("Successfully saved papers to papers.json")
|
||
|
}
|
||
|
|
||
|
// Print paper titles for verification
|
||
|
log.Printf("Processing papers:")
|
||
|
for i, paper := range llmPapers {
|
||
|
log.Printf(" %d. %s", i+1, paper.Title)
|
||
|
}
|
||
|
|
||
|
// Save papers to temporary file for LLM processing
|
||
|
tempInput := "temp_input.json"
|
||
|
tempJSON, err := json.Marshal(llmPapers)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to marshal papers for LLM: %v", err)
|
||
|
}
|
||
|
if err := os.WriteFile(tempInput, tempJSON, 0644); err != nil {
|
||
|
log.Fatalf("Failed to save temp input JSON: %v", err)
|
||
|
}
|
||
|
|
||
|
// Filter papers with LLM
|
||
|
log.Printf("Starting LLM processing")
|
||
|
apiKey := os.Getenv("OPENROUTER_API_KEY")
|
||
|
if apiKey == "" {
|
||
|
log.Fatal("OPENROUTER_API_KEY environment variable is required")
|
||
|
}
|
||
|
|
||
|
llmProcessor := processor.NewProcessor(*model, 32, apiKey) // 32 = batch size from README
|
||
|
log.Printf("Initialized LLM processor with model %s", *model)
|
||
|
|
||
|
tempOutput := "temp_output.json"
|
||
|
log.Printf("Processing papers with criteria from %s", *criteriaFile)
|
||
|
if err := llmProcessor.ProcessPapers(ctx, tempInput, tempOutput, *criteriaFile, 1*time.Second); err != nil {
|
||
|
log.Fatalf("LLM processing failed: %v", err)
|
||
|
}
|
||
|
log.Printf("LLM processing complete, results saved to %s", tempOutput)
|
||
|
|
||
|
// Generate markdown
|
||
|
log.Printf("Generating markdown output")
|
||
|
decisions, err := lib.ProcessJSONFile(tempOutput)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to process JSON: %v", err)
|
||
|
}
|
||
|
log.Printf("Processed decisions: %d accepted, %d rejected", len(decisions.Accepted), len(decisions.Rejected))
|
||
|
|
||
|
if err := lib.GenerateMarkdown(decisions, *output); err != nil {
|
||
|
log.Fatalf("Markdown generation failed: %v", err)
|
||
|
}
|
||
|
log.Printf("Generated markdown output at %s", *output)
|
||
|
|
||
|
// Cleanup temp files
|
||
|
os.Remove(tempInput)
|
||
|
os.Remove(tempOutput)
|
||
|
log.Printf("Cleaned up temporary files")
|
||
|
|
||
|
log.Printf("Process complete. Results saved to %s", *output)
|
||
|
}
// parseDate converts a YYYYMMDD string (e.g. "20240131") into a time.Time.
// No zone is given in the layout, so time.Parse yields midnight UTC.
// A malformed value terminates the program through log.Fatalf, since the
// function is used only to interpret command-line input.
func parseDate(s string) time.Time {
	const compactDateLayout = "20060102"

	parsed, parseErr := time.Parse(compactDateLayout, s)
	if parseErr == nil {
		return parsed
	}
	log.Fatalf("Invalid date %q: %v", s, parseErr)
	return parsed // unreachable: log.Fatalf exits the process
}