ira/sim-search-api/tests/test_api.py

481 lines
16 KiB
Python

"""
Test script for the sim-search API.
This script tests the core functionality of the API, including authentication,
query processing, search execution, and report generation.
"""
import os
import sys
import asyncio
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
# Add the project root directory to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from app.main import app
from app.db.session import Base
from app.db.models import User
from app.core.security import get_password_hash
from app.core.config import settings
from app.api.dependencies import get_db
# SQLite database file used only by this test module. The ``setup_database``
# fixture creates its schema before the tests and deletes the file afterwards.
TEST_SQLALCHEMY_DATABASE_URI = "sqlite:///./test.db"

# StaticPool hands out one shared connection for every checkout.
# check_same_thread=False lets that connection be used from threads other than
# the one that opened it (FastAPI may run sync dependencies in a threadpool).
engine = create_engine(
    TEST_SQLALCHEMY_DATABASE_URI,
    poolclass=StaticPool,
    connect_args=dict(check_same_thread=False),
)

# Session factory bound to the test engine; used by the get_db override and
# by the fixture that seeds the test user.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
# Override the get_db dependency so API routes use the test database.
def override_get_db():
    """Yield a session bound to the test database and always close it.

    The session is created *before* entering ``try``. If creating it fails,
    the ``finally`` clause therefore does not touch an unbound ``db``, which
    would mask the original error with an ``UnboundLocalError``.
    """
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db
# Credentials of the user seeded by the ``setup_database`` fixture.
test_user_email = "test@example.com"
test_user_password = "password123"
test_user_full_name = "Test User"

# A single in-process client for ``app``, shared by every test in this module.
client = TestClient(app)
def _remove_test_db_file():
    """Delete the SQLite file backing ``engine``, if it exists.

    The path comes from the engine URL, so cleanup always targets the
    database the tests actually use. In-memory SQLite URLs have no file,
    so there is nothing to remove for them.
    """
    db_path = engine.url.database
    if not db_path or db_path == ":memory:":
        return
    if os.path.exists(db_path):
        os.remove(db_path)


@pytest.fixture(scope="module")
def setup_database():
    """Create a fresh test database seeded with the test user.

    Module-scoped: the schema and seed user are created once, before the
    first test that requests this fixture. Both are torn down after the last
    test in the module.

    Raises:
        RuntimeError: if the test user cannot be inserted. Every
            authenticated test depends on that user, so failing here is
            clearer than failing later at login.
    """
    # Close pooled connections before touching the file. StaticPool keeps a
    # single connection open, and an open SQLite file cannot be deleted on
    # Windows.
    engine.dispose()
    _remove_test_db_file()
    Base.metadata.create_all(bind=engine)

    db = TestingSessionLocal()
    try:
        # Building the user (including password hashing) happens inside the
        # try so the session is closed even if this step raises.
        db.add(
            User(
                email=test_user_email,
                hashed_password=get_password_hash(test_user_password),
                full_name=test_user_full_name,
                is_active=True,
                is_superuser=False,
            )
        )
        db.commit()
    except Exception as exc:
        db.rollback()
        raise RuntimeError(f"Error creating test user: {exc}") from exc
    finally:
        db.close()

    yield

    # Teardown: drop the schema, release the pooled connection, then delete
    # the database file.
    Base.metadata.drop_all(bind=engine)
    engine.dispose()
    _remove_test_db_file()
@pytest.fixture(scope="module")
def auth_token(setup_database):
    """Log in as the seeded test user and return the bearer access token.

    Module-scoped, so the login request is made once. The resulting token is
    shared by every test in this module that requests the fixture.
    """
    login_form = {"username": test_user_email, "password": test_user_password}
    resp = client.post(f"{settings.API_V1_STR}/auth/token", data=login_form)
    assert resp.status_code == 200

    payload = resp.json()
    assert "access_token" in payload
    assert payload["token_type"] == "bearer"
    return payload["access_token"]
def test_root():
    """GET / reports that the service is online, along with its metadata."""
    resp = client.get("/")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "online"
    assert payload["version"] == settings.VERSION
    assert payload["project"] == settings.PROJECT_NAME
    assert payload["docs"] == "/docs"
def test_auth_token(setup_database):
    """Valid credentials for the seeded user yield a bearer token."""
    token_url = f"{settings.API_V1_STR}/auth/token"
    credentials = {"username": test_user_email, "password": test_user_password}

    resp = client.post(token_url, data=credentials)
    assert resp.status_code == 200

    body = resp.json()
    assert "access_token" in body
    assert body["token_type"] == "bearer"
def test_auth_token_invalid_credentials(setup_database):
    """A wrong password for an existing user is rejected with 401."""
    bad_credentials = {"username": test_user_email, "password": "wrong_password"}
    resp = client.post(f"{settings.API_V1_STR}/auth/token", data=bad_credentials)

    assert resp.status_code == 401
    assert resp.json()["detail"] == "Incorrect email or password"
def test_register_user(setup_database):
    """Registering a new email returns the created user's public fields."""
    new_user = {
        "email": "new_user@example.com",
        "password": "password123",
        "full_name": "New User",
        "is_active": True,
        "is_superuser": False,
    }
    response = client.post(f"{settings.API_V1_STR}/auth/register", json=new_user)
    assert response.status_code == 200

    user_data = response.json()
    assert user_data["email"] == "new_user@example.com"
    assert user_data["full_name"] == "New User"
    # Identity checks (PEP 8 / E712) instead of ``== True`` / ``== False``.
    # They also fail if the API returns 1/0 instead of real JSON booleans.
    assert user_data["is_active"] is True
    assert user_data["is_superuser"] is False
def test_register_existing_user(setup_database):
    """Re-registering the seeded user's email is refused with 400."""
    duplicate = dict(
        email=test_user_email,
        password="password123",
        full_name="Duplicate User",
        is_active=True,
        is_superuser=False,
    )
    resp = client.post(f"{settings.API_V1_STR}/auth/register", json=duplicate)

    assert resp.status_code == 400
    assert resp.json()["detail"] == "A user with this email already exists"
def test_process_query(auth_token):
    """Processing a query echoes it back alongside a structured form."""
    question = "What are the environmental impacts of electric vehicles?"
    auth = {"Authorization": f"Bearer {auth_token}"}

    resp = client.post(
        f"{settings.API_V1_STR}/query/process",
        json={"query": question},
        headers=auth,
    )
    assert resp.status_code == 200

    result = resp.json()
    assert result["original_query"] == question
    assert "structured_query" in result
    assert result["structured_query"]["original_query"] == question
def test_classify_query(auth_token):
    """Classifying a query yields a structured query with a type and domain."""
    question = "What are the environmental impacts of electric vehicles?"
    resp = client.post(
        f"{settings.API_V1_STR}/query/classify",
        json={"query": question},
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200

    result = resp.json()
    assert result["original_query"] == question
    assert "structured_query" in result

    structured = result["structured_query"]
    assert structured["original_query"] == question
    assert "type" in structured
    assert "domain" in structured
def test_get_available_search_engines(auth_token):
    """The engines endpoint returns a non-empty list of search engines."""
    resp = client.get(
        f"{settings.API_V1_STR}/search/engines",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200

    available = resp.json()
    assert isinstance(available, list)
    # Already known to be a list, so truthiness means "at least one entry".
    assert available
def test_execute_search(auth_token):
    """Executing a search returns an id, the query, results and timing."""
    question = "What are the environmental impacts of electric vehicles?"
    request_body = {
        "structured_query": {
            "original_query": question,
            "enhanced_query": question,
            "type": "factual",
            "domain": "environmental",
        },
        "search_engines": ["google", "arxiv"],
        "num_results": 5,
        "timeout": 30,
    }

    resp = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json=request_body,
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200

    outcome = resp.json()
    assert "search_id" in outcome
    assert outcome["query"] == question
    for field in ("results", "total_results", "execution_time"):
        assert field in outcome
def test_get_search_history(auth_token):
    """Search history is a list of searches plus an integer total."""
    resp = client.get(
        f"{settings.API_V1_STR}/search/history",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200

    history = resp.json()
    for key in ("searches", "total"):
        assert key in history
    assert isinstance(history["searches"], list)
    assert isinstance(history["total"], int)
def test_get_search_results(auth_token):
    """A stored search can be fetched back by the id returned on execution."""
    question = "What are the economic benefits of electric vehicles?"
    auth = {"Authorization": f"Bearer {auth_token}"}
    api = settings.API_V1_STR

    # Run a search first so there is a search_id to look up.
    executed = client.post(
        f"{api}/search/execute",
        json={
            "structured_query": {
                "original_query": question,
                "enhanced_query": question,
                "type": "factual",
                "domain": "economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth,
    )
    assert executed.status_code == 200
    search_id = executed.json()["search_id"]

    fetched = client.get(f"{api}/search/{search_id}", headers=auth)
    assert fetched.status_code == 200

    data = fetched.json()
    assert data["search_id"] == search_id
    assert data["query"] == question
    assert "results" in data
    assert "total_results" in data
def test_generate_report(auth_token):
    """Generating a report from a search returns its metadata and progress."""
    question = "What are the environmental and economic impacts of electric vehicles?"
    model_name = "llama-3.1-8b-instant"
    auth = {"Authorization": f"Bearer {auth_token}"}
    api = settings.API_V1_STR

    # A report is generated from an existing search, so run one first.
    search_resp = client.post(
        f"{api}/search/execute",
        json={
            "structured_query": {
                "original_query": question,
                "enhanced_query": question,
                "type": "comparative",
                "domain": "environmental,economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth,
    )
    assert search_resp.status_code == 200
    search_id = search_resp.json()["search_id"]

    report_resp = client.post(
        f"{api}/report/generate",
        json={
            "search_id": search_id,
            "query": question,
            "detail_level": "standard",
            "query_type": "comparative",
            "model": model_name,
        },
        headers=auth,
    )
    assert report_resp.status_code == 200

    report = report_resp.json()
    assert "id" in report
    assert report["title"].startswith(
        "Report: What are the environmental and economic impacts"
    )
    assert report["detail_level"] == "standard"
    assert report["query_type"] == "comparative"
    assert report["model_used"] == model_name

    # The progress endpoint must answer for the report that was just created.
    report_id = report["id"]
    progress_resp = client.get(f"{api}/report/{report_id}/progress", headers=auth)
    assert progress_resp.status_code == 200

    progress = progress_resp.json()
    assert progress["report_id"] == report_id
    assert "progress" in progress
    assert "status" in progress
def test_get_report_list(auth_token):
    """The report list endpoint returns a list of reports and a total count."""
    resp = client.get(
        f"{settings.API_V1_STR}/report/list",
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert resp.status_code == 200

    listing = resp.json()
    assert "reports" in listing
    assert "total" in listing
    assert isinstance(listing["reports"], list)
    assert isinstance(listing["total"], int)
def test_get_report(auth_token):
    """A single report can be fetched using an id taken from the report list.

    Relies on an earlier test in this module (e.g. ``test_generate_report``,
    which pytest runs first by definition order) having created a report.
    """
    auth = {"Authorization": f"Bearer {auth_token}"}

    listing = client.get(f"{settings.API_V1_STR}/report/list", headers=auth)
    assert listing.status_code == 200
    reports = listing.json()["reports"]
    assert len(reports) > 0
    report_id = reports[0]["id"]

    detail = client.get(f"{settings.API_V1_STR}/report/{report_id}", headers=auth)
    assert detail.status_code == 200

    report = detail.json()
    assert report["id"] == report_id
    for field in ("title", "content", "detail_level", "query_type", "model_used"):
        assert field in report
def test_download_report(auth_token):
    """A freshly generated report downloads as both markdown and HTML."""
    question = "What are the environmental impacts of electric vehicles?"
    auth = {"Authorization": f"Bearer {auth_token}"}
    api = settings.API_V1_STR

    # A smaller search than the other tests (one engine, two results,
    # 10s timeout) to base the report on.
    search_resp = client.post(
        f"{api}/search/execute",
        json={
            "structured_query": {
                "original_query": question,
                "enhanced_query": question,
                "type": "comparative",
                "domain": "environmental,economic",
            },
            "search_engines": ["serper"],
            "num_results": 2,
            "timeout": 10,
        },
        headers=auth,
    )
    assert search_resp.status_code == 200
    search_id = search_resp.json()["search_id"]

    report_resp = client.post(
        f"{api}/report/generate",
        json={
            "search_id": search_id,
            "query": question,
            "detail_level": "brief",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=auth,
    )
    assert report_resp.status_code == 200
    report_id = report_resp.json()["id"]

    # Each format must be served as an octet-stream attachment. The file is
    # named after the report id, with the format name as the extension.
    for fmt in ("markdown", "html"):
        download = client.get(
            f"{api}/report/{report_id}/download?format={fmt}", headers=auth
        )
        assert download.status_code == 200
        assert download.headers["content-type"] == "application/octet-stream"
        assert download.headers["content-disposition"] == (
            f'attachment; filename="report_{report_id}.{fmt}"'
        )
def test_delete_report(auth_token):
    """Deleting a report returns 204; fetching it afterwards returns 404."""
    auth = {"Authorization": f"Bearer {auth_token}"}
    report_base = f"{settings.API_V1_STR}/report"

    listing = client.get(f"{report_base}/list", headers=auth)
    assert listing.status_code == 200
    reports = listing.json()["reports"]
    assert len(reports) > 0
    report_id = reports[0]["id"]

    deleted = client.delete(f"{report_base}/{report_id}", headers=auth)
    assert deleted.status_code == 204

    # A follow-up lookup must now report the report as missing.
    lookup = client.get(f"{report_base}/{report_id}", headers=auth)
    assert lookup.status_code == 404
def test_delete_search(auth_token):
    """Deleting a search returns 204; fetching it afterwards returns 404."""
    auth = {"Authorization": f"Bearer {auth_token}"}
    search_base = f"{settings.API_V1_STR}/search"

    history = client.get(f"{search_base}/history", headers=auth)
    assert history.status_code == 200
    searches = history.json()["searches"]
    assert len(searches) > 0
    search_id = searches[0]["id"]

    deleted = client.delete(f"{search_base}/{search_id}", headers=auth)
    assert deleted.status_code == 204

    # A follow-up lookup must now report the search as missing.
    lookup = client.get(f"{search_base}/{search_id}", headers=auth)
    assert lookup.status_code == 404
if __name__ == "__main__":
    # Propagate pytest's exit code so a shell or CI job running this script
    # directly sees test failures (previously the script always exited 0).
    sys.exit(pytest.main(["-xvs", __file__]))