174 lines
3.8 KiB
Go
174 lines
3.8 KiB
Go
package arxiva
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"testing"
|
|
)
|
|
|
|
// mockTransport is a stub http.RoundTripper. Every request it receives is
// answered with the preconfigured response and error, so tests can drive an
// *http.Client without touching the network, e.g.
//
//	client := &http.Client{Transport: &mockTransport{response: resp}}
//
// The same *http.Response value is handed back on every call, so its Body can
// only be consumed once; build a fresh mockTransport per request-sequence.
type mockTransport struct {
	response *http.Response // returned from RoundTrip as-is (may be nil)
	err      error          // returned from RoundTrip; set instead of response to simulate a transport failure
}

// Compile-time check that *mockTransport satisfies http.RoundTripper.
var _ http.RoundTripper = (*mockTransport)(nil)

// RoundTrip returns the configured response and error. When a response is
// configured, its Request field is pointed at req, mirroring real transports
// so code under test can safely inspect resp.Request. req itself is never
// modified, as the RoundTripper contract requires.
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if m.response != nil {
		m.response.Request = req
	}
	return m.response, m.err
}
|
|
|
|
func TestFetchPapers(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mockResponse string
|
|
mockError error
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "successful fetch",
|
|
mockResponse: `<?xml version="1.0" encoding="UTF-8"?>
|
|
<feed xmlns="http://www.w3.org/2005/Atom">
|
|
<entry>
|
|
<title>Test Paper</title>
|
|
<summary>Test summary</summary>
|
|
<id>http://arxiv.org/abs/1234.56789</id>
|
|
</entry>
|
|
</feed>`,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "api error",
|
|
mockError: errors.New("API unavailable"),
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid json",
|
|
mockResponse: "invalid json",
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Setup mock HTTP client
|
|
var client *http.Client
|
|
var response *http.Response
|
|
if tt.mockError == nil {
|
|
response = &http.Response{
|
|
StatusCode: http.StatusOK,
|
|
Body: io.NopCloser(bytes.NewBufferString(tt.mockResponse)),
|
|
}
|
|
}
|
|
client = &http.Client{
|
|
Transport: &mockTransport{
|
|
response: response,
|
|
err: tt.mockError,
|
|
},
|
|
}
|
|
|
|
// Replace default client with mock
|
|
originalClient := httpClient
|
|
httpClient = client
|
|
defer func() { httpClient = originalClient }()
|
|
|
|
// Test function
|
|
papers, err := FetchPapers("20240101", "20240125", "test", 1)
|
|
if err != nil {
|
|
t.Logf("FetchPapers error: %v", err)
|
|
}
|
|
if (err != nil) != tt.expectError {
|
|
t.Errorf("FetchPapers() error = %v, expectError %v", err, tt.expectError)
|
|
return
|
|
}
|
|
|
|
if !tt.expectError && len(papers) == 0 {
|
|
t.Error("Expected papers but got none")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSaveToFile(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
papers []Paper
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "successful save",
|
|
papers: []Paper{
|
|
{
|
|
Title: "Test Paper",
|
|
Abstract: "Test abstract",
|
|
ArxivID: "1234.56789",
|
|
},
|
|
},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "empty papers",
|
|
papers: []Paper{},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create temp file
|
|
tmpFile, err := os.CreateTemp("", "test-*.json")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
// Test function
|
|
filename := "20240101-20240125-test.json"
|
|
err = SaveToFile(tt.papers, "20240101", "20240125", "test")
|
|
if err != nil && !tt.expectError {
|
|
t.Errorf("SaveToFile() unexpected error: %v", err)
|
|
return
|
|
}
|
|
if err == nil && tt.expectError {
|
|
t.Errorf("SaveToFile() expected error but got none")
|
|
return
|
|
}
|
|
if tt.expectError {
|
|
return
|
|
}
|
|
|
|
// Clean up test file
|
|
defer os.Remove(filename)
|
|
if (err != nil) != tt.expectError {
|
|
t.Errorf("SaveToFile() error = %v, expectError %v", err, tt.expectError)
|
|
return
|
|
}
|
|
|
|
// Verify file contents if successful
|
|
if !tt.expectError {
|
|
var savedPapers []Paper
|
|
fileData, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
t.Errorf("Failed to read saved file: %v", err)
|
|
return
|
|
}
|
|
|
|
if err := json.Unmarshal(fileData, &savedPapers); err != nil {
|
|
t.Errorf("Failed to unmarshal saved file: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(savedPapers) != len(tt.papers) {
|
|
t.Errorf("Expected %d papers, got %d", len(tt.papers), len(savedPapers))
|
|
}
|
|
} else if tt.name == "empty papers" {
|
|
if _, err := os.Stat(filename); !os.IsNotExist(err) {
|
|
t.Errorf("File should not exist for empty papers case")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|