Add unit tests
This commit is contained in:
parent
d2190124f9
commit
6f7eea0373
|
@ -0,0 +1 @@
|
|||
*.json
|
12
arxiva.go
12
arxiva.go
|
@ -11,6 +11,8 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{}
|
||||
|
||||
type Paper struct {
|
||||
Title string `json:"title"`
|
||||
Abstract string `json:"abstract"`
|
||||
|
@ -63,7 +65,11 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int)
|
|||
params.Add("search_query", searchQuery)
|
||||
params.Add("max_results", fmt.Sprintf("%d", maxPapers))
|
||||
|
||||
resp, err := http.Get(baseURL + "?" + params.Encode())
|
||||
req, err := http.NewRequest("GET", baseURL+"?"+params.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch from arXiv: %v", err)
|
||||
}
|
||||
|
@ -93,6 +99,10 @@ func FetchPapers(startDate string, endDate string, query string, maxPapers int)
|
|||
}
|
||||
|
||||
func SaveToFile(papers []Paper, startDate string, endDate string, query string) error {
|
||||
if len(papers) == 0 {
|
||||
return fmt.Errorf("no papers to save")
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("%s-%s-%s.json", startDate, endDate, SanitizeFilename(query))
|
||||
|
||||
data, err := json.MarshalIndent(papers, "", " ")
|
||||
|
|
|
@ -0,0 +1,173 @@
|
|||
package arxiva
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Mock HTTP client that implements the http.RoundTripper interface
|
||||
type mockTransport struct {
|
||||
response *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return m.response, m.err
|
||||
}
|
||||
|
||||
func TestFetchPapers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockResponse string
|
||||
mockError error
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful fetch",
|
||||
mockResponse: `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<feed xmlns="http://www.w3.org/2005/Atom">
|
||||
<entry>
|
||||
<title>Test Paper</title>
|
||||
<summary>Test summary</summary>
|
||||
<id>http://arxiv.org/abs/1234.56789</id>
|
||||
</entry>
|
||||
</feed>`,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "api error",
|
||||
mockError: errors.New("API unavailable"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
mockResponse: "invalid json",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup mock HTTP client
|
||||
var client *http.Client
|
||||
var response *http.Response
|
||||
if tt.mockError == nil {
|
||||
response = &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(tt.mockResponse)),
|
||||
}
|
||||
}
|
||||
client = &http.Client{
|
||||
Transport: &mockTransport{
|
||||
response: response,
|
||||
err: tt.mockError,
|
||||
},
|
||||
}
|
||||
|
||||
// Replace default client with mock
|
||||
originalClient := httpClient
|
||||
httpClient = client
|
||||
defer func() { httpClient = originalClient }()
|
||||
|
||||
// Test function
|
||||
papers, err := FetchPapers("20240101", "20240125", "test", 1)
|
||||
if err != nil {
|
||||
t.Logf("FetchPapers error: %v", err)
|
||||
}
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("FetchPapers() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.expectError && len(papers) == 0 {
|
||||
t.Error("Expected papers but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
papers []Paper
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful save",
|
||||
papers: []Paper{
|
||||
{
|
||||
Title: "Test Paper",
|
||||
Abstract: "Test abstract",
|
||||
ArxivID: "1234.56789",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty papers",
|
||||
papers: []Paper{},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create temp file
|
||||
tmpFile, err := os.CreateTemp("", "test-*.json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// Test function
|
||||
filename := "20240101-20240125-test.json"
|
||||
err = SaveToFile(tt.papers, "20240101", "20240125", "test")
|
||||
if err != nil && !tt.expectError {
|
||||
t.Errorf("SaveToFile() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if err == nil && tt.expectError {
|
||||
t.Errorf("SaveToFile() expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
// Clean up test file
|
||||
defer os.Remove(filename)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("SaveToFile() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file contents if successful
|
||||
if !tt.expectError {
|
||||
var savedPapers []Paper
|
||||
fileData, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(fileData, &savedPapers); err != nil {
|
||||
t.Errorf("Failed to unmarshal saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(savedPapers) != len(tt.papers) {
|
||||
t.Errorf("Expected %d papers, got %d", len(tt.papers), len(savedPapers))
|
||||
}
|
||||
} else if tt.name == "empty papers" {
|
||||
if _, err := os.Stat(filename); !os.IsNotExist(err) {
|
||||
t.Errorf("File should not exist for empty papers case")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue