arxiva/arxiva_test.go

174 lines
3.8 KiB
Go

package arxiva
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"os"
"testing"
)
// mockTransport is a test stub for http.RoundTripper. Installed as the
// Transport of an http.Client, it lets a test decide what every request
// yields without opening a network connection.
type mockTransport struct {
	response *http.Response // handed back unchanged by RoundTrip
	err      error          // handed back unchanged by RoundTrip
}

// RoundTrip ignores the outgoing request and returns the preconfigured
// response and error exactly as stored on the stub.
func (mt *mockTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	return mt.response, mt.err
}
func TestFetchPapers(t *testing.T) {
tests := []struct {
name string
mockResponse string
mockError error
expectError bool
}{
{
name: "successful fetch",
mockResponse: `<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<entry>
<title>Test Paper</title>
<summary>Test summary</summary>
<id>http://arxiv.org/abs/1234.56789</id>
</entry>
</feed>`,
expectError: false,
},
{
name: "api error",
mockError: errors.New("API unavailable"),
expectError: true,
},
{
name: "invalid json",
mockResponse: "invalid json",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup mock HTTP client
var client *http.Client
var response *http.Response
if tt.mockError == nil {
response = &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(tt.mockResponse)),
}
}
client = &http.Client{
Transport: &mockTransport{
response: response,
err: tt.mockError,
},
}
// Replace default client with mock
originalClient := httpClient
httpClient = client
defer func() { httpClient = originalClient }()
// Test function
papers, err := FetchPapers("20240101", "20240125", "test", 1)
if err != nil {
t.Logf("FetchPapers error: %v", err)
}
if (err != nil) != tt.expectError {
t.Errorf("FetchPapers() error = %v, expectError %v", err, tt.expectError)
return
}
if !tt.expectError && len(papers) == 0 {
t.Error("Expected papers but got none")
}
})
}
}
// TestSaveToFile checks that SaveToFile persists a non-empty slice of papers
// as JSON that decodes back to the same number of entries, and that it returns
// an error for an empty slice without leaving an output file behind.
func TestSaveToFile(t *testing.T) {
	tests := []struct {
		name        string
		papers      []Paper
		expectError bool
	}{
		{
			name: "successful save",
			papers: []Paper{
				{
					Title:    "Test Paper",
					Abstract: "Test abstract",
					ArxivID:  "1234.56789",
				},
			},
			expectError: false,
		},
		{
			name:        "empty papers",
			papers:      []Paper{},
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// NOTE(review): assumes SaveToFile derives this name from its
			// start/end/query arguments and writes it into the current working
			// directory — confirm against the implementation.
			filename := "20240101-20240125-test.json"

			// Registered before the call so that no exit path can leave the
			// file behind. The error is ignored because the file may never
			// have been created.
			t.Cleanup(func() { _ = os.Remove(filename) })

			err := SaveToFile(tt.papers, "20240101", "20240125", "test")
			switch {
			case err != nil && !tt.expectError:
				t.Errorf("SaveToFile() unexpected error: %v", err)
				return
			case err == nil && tt.expectError:
				t.Errorf("SaveToFile() expected error but got none")
				return
			}

			if tt.expectError {
				// A rejected save must not have produced an output file.
				if _, statErr := os.Stat(filename); !os.IsNotExist(statErr) {
					t.Errorf("File should not exist for %s case", tt.name)
				}
				return
			}

			// Success path: the file must exist and round-trip through JSON.
			fileData, err := os.ReadFile(filename)
			if err != nil {
				t.Errorf("Failed to read saved file: %v", err)
				return
			}
			var savedPapers []Paper
			if err := json.Unmarshal(fileData, &savedPapers); err != nil {
				t.Errorf("Failed to unmarshal saved file: %v", err)
				return
			}
			if len(savedPapers) != len(tt.papers) {
				t.Errorf("Expected %d papers, got %d", len(tt.papers), len(savedPapers))
			}
		})
	}
}