Add unit tests

This commit is contained in:
Steve White 2025-01-25 11:35:15 -06:00
parent d2190124f9
commit 6f7eea0373
3 changed files with 185 additions and 1 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.json

View File

@ -11,6 +11,8 @@ import (
"time"
)
var httpClient = &http.Client{}
type Paper struct {
Title string `json:"title"`
Abstract string `json:"abstract"`
@ -63,7 +65,11 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int)
params.Add("search_query", searchQuery)
params.Add("max_results", fmt.Sprintf("%d", maxPapers))
resp, err := http.Get(baseURL + "?" + params.Encode())
req, err := http.NewRequest("GET", baseURL+"?"+params.Encode(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch from arXiv: %v", err)
}
@ -93,6 +99,10 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int)
}
func SaveToFile(papers []Paper, startDate string, endDate string, query string) error {
if len(papers) == 0 {
return fmt.Errorf("no papers to save")
}
filename := fmt.Sprintf("%s-%s-%s.json", startDate, endDate, SanitizeFilename(query))
data, err := json.MarshalIndent(papers, "", " ")

173
arxiva_test.go Normal file
View File

@ -0,0 +1,173 @@
package arxiva
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"testing"
)
// Mock HTTP client that implements the http.RoundTripper interface
type mockTransport struct {
response *http.Response
err error
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return m.response, m.err
}
func TestFetchPapers(t *testing.T) {
tests := []struct {
name string
mockResponse string
mockError error
expectError bool
}{
{
name: "successful fetch",
mockResponse: `<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<entry>
<title>Test Paper</title>
<summary>Test summary</summary>
<id>http://arxiv.org/abs/1234.56789</id>
</entry>
</feed>`,
expectError: false,
},
{
name: "api error",
mockError: errors.New("API unavailable"),
expectError: true,
},
{
name: "invalid json",
mockResponse: "invalid json",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup mock HTTP client
var client *http.Client
var response *http.Response
if tt.mockError == nil {
response = &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(tt.mockResponse)),
}
}
client = &http.Client{
Transport: &mockTransport{
response: response,
err: tt.mockError,
},
}
// Replace default client with mock
originalClient := httpClient
httpClient = client
defer func() { httpClient = originalClient }()
// Test function
papers, err := FetchPapers("20240101", "20240125", "test", 1)
if err != nil {
t.Logf("FetchPapers error: %v", err)
}
if (err != nil) != tt.expectError {
t.Errorf("FetchPapers() error = %v, expectError %v", err, tt.expectError)
return
}
if !tt.expectError && len(papers) == 0 {
t.Error("Expected papers but got none")
}
})
}
}
func TestSaveToFile(t *testing.T) {
tests := []struct {
name string
papers []Paper
expectError bool
}{
{
name: "successful save",
papers: []Paper{
{
Title: "Test Paper",
Abstract: "Test abstract",
ArxivID: "1234.56789",
},
},
expectError: false,
},
{
name: "empty papers",
papers: []Paper{},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create temp file
tmpFile, err := os.CreateTemp("", "test-*.json")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpFile.Name())
// Test function
filename := "20240101-20240125-test.json"
err = SaveToFile(tt.papers, "20240101", "20240125", "test")
if err != nil && !tt.expectError {
t.Errorf("SaveToFile() unexpected error: %v", err)
return
}
if err == nil && tt.expectError {
t.Errorf("SaveToFile() expected error but got none")
return
}
if tt.expectError {
return
}
// Clean up test file
defer os.Remove(filename)
if (err != nil) != tt.expectError {
t.Errorf("SaveToFile() error = %v, expectError %v", err, tt.expectError)
return
}
// Verify file contents if successful
if !tt.expectError {
var savedPapers []Paper
fileData, err := os.ReadFile(filename)
if err != nil {
t.Errorf("Failed to read saved file: %v", err)
return
}
if err := json.Unmarshal(fileData, &savedPapers); err != nil {
t.Errorf("Failed to unmarshal saved file: %v", err)
return
}
if len(savedPapers) != len(tt.papers) {
t.Errorf("Expected %d papers, got %d", len(tt.papers), len(savedPapers))
}
} else if tt.name == "empty papers" {
if _, err := os.Stat(filename); !os.IsNotExist(err) {
t.Errorf("File should not exist for empty papers case")
}
}
})
}
}