diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a6c57f5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.json diff --git a/arxiva.go b/arxiva.go index a019b05..ba4266e 100644 --- a/arxiva.go +++ b/arxiva.go @@ -11,6 +11,8 @@ import ( "time" ) +var httpClient = &http.Client{} + type Paper struct { Title string `json:"title"` Abstract string `json:"abstract"` @@ -63,7 +65,11 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int) params.Add("search_query", searchQuery) params.Add("max_results", fmt.Sprintf("%d", maxPapers)) - resp, err := http.Get(baseURL + "?" + params.Encode()) + req, err := http.NewRequest("GET", baseURL+"?"+params.Encode(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %v", err) + } + resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch from arXiv: %v", err) } @@ -93,6 +99,10 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int) } func SaveToFile(papers []Paper, startDate string, endDate string, query string) error { + if len(papers) == 0 { + return fmt.Errorf("no papers to save") + } + filename := fmt.Sprintf("%s-%s-%s.json", startDate, endDate, SanitizeFilename(query)) data, err := json.MarshalIndent(papers, "", " ") diff --git a/arxiva_test.go b/arxiva_test.go new file mode 100644 index 0000000..25a1070 --- /dev/null +++ b/arxiva_test.go @@ -0,0 +1,173 @@ +package arxiva + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "os" + "testing" +) + +// Mock HTTP client that implements the http.RoundTripper interface +type mockTransport struct { + response *http.Response + err error +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return m.response, m.err +} + +func TestFetchPapers(t *testing.T) { + tests := []struct { + name string + mockResponse string + mockError error + expectError bool + }{ + { + name: "successful fetch", + mockResponse: ` + + + Test Paper + Test summary + http://arxiv.org/abs/1234.56789 + +`, + expectError: false, + }, + { + name: "api error", + mockError: errors.New("API unavailable"), + expectError: true, + }, + { + name: "invalid json", + mockResponse: "invalid json", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup mock HTTP client + var client *http.Client + var response *http.Response + if tt.mockError == nil { + response = &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewBufferString(tt.mockResponse)), + } + } + client = &http.Client{ + Transport: &mockTransport{ + response: response, + err: tt.mockError, + }, + } + + // Replace default client with mock + originalClient := httpClient + httpClient = client + defer func() { httpClient = originalClient }() + + // Test function + papers, err := FetchPapers("20240101", "20240125", "test", 1) + if err != nil { + t.Logf("FetchPapers error: %v", err) + } + if (err != nil) != tt.expectError { + t.Errorf("FetchPapers() error = %v, expectError %v", err, tt.expectError) + return + } + + if !tt.expectError && len(papers) == 0 { + t.Error("Expected papers but got none") + } + }) + } +} + +func TestSaveToFile(t *testing.T) { + tests := []struct { + name string + papers []Paper + expectError bool + }{ + { + name: "successful save", + papers: []Paper{ + { + Title: "Test Paper", + Abstract: "Test abstract", + ArxivID: "1234.56789", + }, + }, + expectError: false, + }, + { + name: "empty papers", + papers: []Paper{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temp file + tmpFile, err := os.CreateTemp("", "test-*.json") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpFile.Name()) + + // Test function + filename := "20240101-20240125-test.json" + err = SaveToFile(tt.papers, "20240101", "20240125", "test") + if err != nil && !tt.expectError { + t.Errorf("SaveToFile() unexpected error: %v", err) + return + } + if err == nil && tt.expectError { + t.Errorf("SaveToFile() expected error but got none") + return + } + if tt.expectError { + return + } + + // Clean up test file + defer os.Remove(filename) + if (err != nil) != tt.expectError { + t.Errorf("SaveToFile() error = %v, expectError %v", err, tt.expectError) + return + } + + // Verify file contents if successful + if !tt.expectError { + var savedPapers []Paper + fileData, err := os.ReadFile(filename) + if err != nil { + t.Errorf("Failed to read saved file: %v", err) + return + } + + if err := json.Unmarshal(fileData, &savedPapers); err != nil { + t.Errorf("Failed to unmarshal saved file: %v", err) + return + } + + if len(savedPapers) != len(tt.papers) { + t.Errorf("Expected %d papers, got %d", len(tt.papers), len(savedPapers)) + } + } else if tt.name == "empty papers" { + if _, err := os.Stat(filename); !os.IsNotExist(err) { + t.Errorf("File should not exist for empty papers case") + } + } + }) + } +}