// papers/papers.go
//
// 263 lines
// 8.8 KiB
// Go

package main
import (
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"time"

	"gitea.r8z.us/stwhite/arxiva"
	"gitea.r8z.us/stwhite/paperformatter"
	"gitea.r8z.us/stwhite/paperprocessor"
)
// Paper is the shape of one element of the JSON array accepted via -input.
// The JSON keys are "title", "abstract" and "arxiv_id".
type Paper struct {
	Title    string `json:"title"`    // paper title
	Abstract string `json:"abstract"` // paper abstract
	ArxivID  string `json:"arxiv_id"` // arXiv identifier, e.g. "2401.01234"
}
// validateInputFile reads the file at path, decodes it as a JSON array of
// Paper, and checks that the array is non-empty and that every entry has a
// non-empty title, abstract and arxiv_id.
//
// It returns the decoded papers, or an error describing the first problem
// found. Open, read and JSON errors are wrapped with %w, so callers can
// inspect the cause with errors.Is / errors.As (e.g. fs.ErrNotExist).
func validateInputFile(path string) ([]Paper, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open input file: %w", err)
	}
	defer file.Close()

	content, err := io.ReadAll(file)
	if err != nil {
		return nil, fmt.Errorf("failed to read input file: %w", err)
	}

	var papers []Paper
	if err := json.Unmarshal(content, &papers); err != nil {
		return nil, fmt.Errorf("invalid JSON format: %w", err)
	}

	// Both `[]` and `null` decode without error. Reject them so the pipeline
	// is not run (and outputs are not written) for zero papers.
	if len(papers) == 0 {
		return nil, fmt.Errorf("input file contains no papers")
	}

	// Report the first missing required field, checked in field order.
	for i, p := range papers {
		switch {
		case p.Title == "":
			return nil, fmt.Errorf("paper at index %d missing title", i)
		case p.Abstract == "":
			return nil, fmt.Errorf("paper at index %d missing abstract", i)
		case p.ArxivID == "":
			return nil, fmt.Errorf("paper at index %d missing arxiv_id", i)
		}
	}
	return papers, nil
}
// sanitizeFilename returns s with every ':' and every ' ' replaced by '_',
// making a search query usable inside a file name. It is meant to mirror the
// sanitization arxiva applies to the files it saves. NOTE(review): that
// correspondence is not checked here; keep it in sync with arxiva.
func sanitizeFilename(s string) string {
	// '_' matches neither pattern, so this single pass gives the same result
	// as replacing ':' first and ' ' second.
	return strings.NewReplacer(":", "_", " ", "_").Replace(s)
}
// yyyymmddPattern matches exactly eight ASCII digits. It is compiled once at
// package initialization rather than on every isValidDate call.
var yyyymmddPattern = regexp.MustCompile(`^\d{8}$`)

// isValidDate reports whether date is a real calendar date written as
// YYYYMMDD (e.g. "20240131"). The regexp enforces the exact eight-digit shape.
// time.Parse then rejects impossible dates such as "20240230" or "20241301".
func isValidDate(date string) bool {
	if !yyyymmddPattern.MatchString(date) {
		return false
	}
	_, err := time.Parse("20060102", date)
	return err == nil
}
// printUsage writes the CLI help text (description, pipeline, required flags,
// flag defaults and examples) to stderr. It is installed as flag.Usage.
func printUsage() {
	fmt.Fprintf(os.Stderr, "Usage: %s [options]\n\n", os.Args[0])
	fmt.Fprintf(os.Stderr, "Description:\n")
	fmt.Fprintf(os.Stderr, " Fetches papers from arXiv (or uses input file), processes them using an LLM, and generates both JSON and Markdown outputs.\n\n")
	fmt.Fprintf(os.Stderr, "Pipeline:\n")
	fmt.Fprintf(os.Stderr, " 1. Either:\n")
	fmt.Fprintf(os.Stderr, " a) Fetches papers from arXiv based on date range and query, or\n")
	fmt.Fprintf(os.Stderr, " b) Uses papers from provided input file\n")
	fmt.Fprintf(os.Stderr, " 2. Processes papers using specified LLM model\n")
	fmt.Fprintf(os.Stderr, " 3. Formats results to both JSON and Markdown\n\n")
	// NOTE(review): -api-key is listed as required, but main does not enforce
	// it (the default -api-endpoint is a localhost URL). Confirm intended.
	fmt.Fprintf(os.Stderr, "Required flags:\n")
	fmt.Fprintf(os.Stderr, " -api-key : API key for LLM service\n\n")
	fmt.Fprintf(os.Stderr, "Required for arXiv fetching (if not using -input):\n")
	fmt.Fprintf(os.Stderr, " -start : Start date (YYYYMMDD)\n")
	fmt.Fprintf(os.Stderr, " -end : End date (YYYYMMDD)\n")
	fmt.Fprintf(os.Stderr, " -query : Search query\n\n")
	fmt.Fprintf(os.Stderr, "Options:\n")
	flag.PrintDefaults()
	fmt.Fprintf(os.Stderr, "\nExamples:\n")
	fmt.Fprintf(os.Stderr, " Using arXiv:\n")
	fmt.Fprintf(os.Stderr, " %s -start 20240101 -end 20240131 -query \"machine learning\" -api-key \"your-key\"\n\n", os.Args[0])
	fmt.Fprintf(os.Stderr, " Using input file:\n")
	fmt.Fprintf(os.Stderr, " %s -input papers.json -api-key \"your-key\"\n\n", os.Args[0])
	fmt.Fprintf(os.Stderr, " With custom options:\n")
	fmt.Fprintf(os.Stderr, " %s -input papers.json -api-key \"your-key\" -model \"gpt-4\" -json-output \"results.json\" -md-output \"summary.md\"\n", os.Args[0])
	fmt.Fprintf(os.Stderr, " Search only:\n")
	fmt.Fprintf(os.Stderr, " %s -search-only -start 20240101 -end 20240131 -query \"machine learning\" \n\n", os.Args[0])
}

// requireFetchFlags validates the flags needed to query arXiv. On the first
// problem it prints an error to stderr and exits with status 1. The "missing
// flags" case also prints usage. The when argument completes that message,
// e.g. "when using -search-only".
func requireFetchFlags(start, end, query string, maxResults int, when string) {
	if start == "" || end == "" || query == "" {
		fmt.Fprintf(os.Stderr, "Error: start date, end date, and query are required %s\n\n", when)
		flag.Usage()
		os.Exit(1)
	}
	if !isValidDate(start) || !isValidDate(end) {
		fmt.Fprintf(os.Stderr, "Error: dates must be in YYYYMMDD format\n")
		os.Exit(1)
	}
	// Both dates are now known to be valid YYYYMMDD strings, so lexical
	// order equals chronological order.
	if start > end {
		fmt.Fprintf(os.Stderr, "Error: start date must not be after end date\n")
		os.Exit(1)
	}
	if maxResults < 1 || maxResults > 2000 {
		fmt.Fprintf(os.Stderr, "Error: maxResults must be between 1 and 2000\n")
		os.Exit(1)
	}
}

// fetchAndSave fetches papers from arXiv and saves them with
// arxiva.SaveToFile, exiting via log.Fatalf on either failure. It returns the
// stem "<start>-<end>-<sanitized query>". The caller appends ".json" to locate
// the saved file. NOTE(review): this assumes the stem matches the name
// arxiva.SaveToFile writes. Confirm against arxiva.
func fetchAndSave(start, end, query string, maxResults int) string {
	papers, err := arxiva.FetchPapers(start, end, query, maxResults)
	if err != nil {
		log.Fatalf("Failed to fetch papers: %v", err)
	}
	if err := arxiva.SaveToFile(papers, start, end, query); err != nil {
		log.Fatalf("Failed to save papers: %v", err)
	}
	return fmt.Sprintf("%s-%s-%s", start, end, sanitizeFilename(query))
}

// main is the CLI entry point. It obtains papers either by fetching them from
// arXiv (and saving them to JSON) or from a validated -input file. It then
// runs paperprocessor.ProcessFile against the criteria file and formats the
// JSON results to Markdown with paperformatter.FormatPapers. With
// -search-only it stops after fetching and saving.
func main() {
	flag.Usage = printUsage

	// Parse command line arguments
	searchOnly := flag.Bool("search-only", false, "Only fetch papers from arXiv and save to JSON file (do not process)")
	inputFile := flag.String("input", "", "Input JSON file containing papers (optional)")
	startDate := flag.String("start", "", "Start date in YYYYMMDD format (required if not using -input)")
	endDate := flag.String("end", "", "End date in YYYYMMDD format (required if not using -input)")
	query := flag.String("query", "", "Search query (required if not using -input)")
	maxResults := flag.Int("maxResults", 100, "Maximum number of results (1-2000)")
	model := flag.String("model", "phi-4", "Model to use for processing")
	apiKey := flag.String("api-key", "", "API key for service authentication")
	apiEndpoint := flag.String("api-endpoint", "http://localhost:1234/v1/chat/completions", "API endpoint URL")
	criteriaFile := flag.String("criteria", "criteria.md", "Path to evaluation criteria markdown file")
	jsonOutput := flag.String("json-output", "", "JSON output file path (default: YYYYMMDD-YYYYMMDD-query.json)")
	mdOutput := flag.String("md-output", "", "Markdown output file path (default: YYYYMMDD-YYYYMMDD-query.md)")
	flag.Parse()

	if *searchOnly {
		requireFetchFlags(*startDate, *endDate, *query, *maxResults, "when using -search-only")
		base := fetchAndSave(*startDate, *endDate, *query, *maxResults)
		log.Printf("Successfully fetched and saved papers to %s.json", base)
		os.Exit(0)
	}

	// baseFilename is the stem for default output names. inputJSON is the
	// file the processor reads.
	var baseFilename, inputJSON string
	if *inputFile != "" {
		// Validate up front so malformed input fails before any LLM call.
		if _, err := validateInputFile(*inputFile); err != nil {
			log.Fatalf("Invalid input file: %v", err)
		}
		// Pass the file through exactly as given. Rebuilding the path as
		// stem+".json" would name a nonexistent file for inputs that don't
		// end in ".json".
		inputJSON = *inputFile
		baseFilename = strings.TrimSuffix(*inputFile, ".json")
	} else {
		requireFetchFlags(*startDate, *endDate, *query, *maxResults, "when not using -input")
		baseFilename = fetchAndSave(*startDate, *endDate, *query, *maxResults)
		inputJSON = baseFilename + ".json"
	}

	// Create processor configuration
	config := paperprocessor.Config{
		APIEndpoint:  *apiEndpoint,
		APIKey:       *apiKey,
		Model:        *model,
		RequestDelay: 2 * time.Second,
	}

	// Use only the criteria file's base name in output names. Keeping its
	// directory (e.g. "configs/criteria.md") would yield paths like
	// "<base>-configs/criteria.json" under a directory that doesn't exist.
	criteriaBase := strings.TrimSuffix(filepath.Base(*criteriaFile), ".md")

	// Set default output filenames if not provided
	if *jsonOutput == "" {
		*jsonOutput = fmt.Sprintf("%s-%s.json", baseFilename, criteriaBase)
	}
	if *mdOutput == "" {
		*mdOutput = fmt.Sprintf("%s-%s.md", baseFilename, criteriaBase)
	}

	// Process the papers
	if err := paperprocessor.ProcessFile(inputJSON, *jsonOutput, *criteriaFile, config); err != nil {
		log.Fatalf("Processing failed: %v", err)
	}

	// Format the processed results to markdown
	if err := paperformatter.FormatPapers(*jsonOutput, *mdOutput); err != nil {
		log.Fatalf("Formatting failed: %v", err)
	}
	log.Printf("Successfully processed papers. Results written to %s and formatted to %s", *jsonOutput, *mdOutput)
}