443 lines
15 KiB
Python
443 lines
15 KiB
Python
"""
|
|
Test script for the sim-search API.

This script tests the core functionality of the API, including authentication,
query processing, search execution, and report generation.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import asyncio
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
# Add the project root directory to the Python path
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
|
|
|
from app.main import app
|
|
from app.db.session import Base
|
|
from app.db.models import User
|
|
from app.core.security import get_password_hash
|
|
from app.core.config import settings
|
|
from app.api.dependencies import get_db
|
|
|
|
# Test database: a SQLite file named test.db, addressed relative to the
# current working directory.
TEST_SQLALCHEMY_DATABASE_URI = "sqlite:///./test.db"

# check_same_thread=False is forwarded to the SQLite driver; StaticPool is the
# pool implementation handed to the engine.
engine = create_engine(
    TEST_SQLALCHEMY_DATABASE_URI,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine, with autoflush and autocommit off.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
# Override the get_db dependency
|
|
def override_get_db():
    """Yield a session from the test session factory and close it afterwards.

    Installed in ``app.dependency_overrides`` as the replacement for the
    application's ``get_db`` dependency.
    """
    # Create the session before entering ``try``: if the factory itself
    # raises there is nothing to close, and the original error propagates
    # instead of being masked by an UnboundLocalError in ``finally``.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
|
|
|
|
# Route the application's ``get_db`` dependency through the test session factory.
app.dependency_overrides.update({get_db: override_get_db})

# Shared HTTP client used by every test in this module.
client = TestClient(app)

# Credentials of the user seeded by the ``setup_database`` fixture.
test_user_email = "test@example.com"
test_user_password = "password123"
test_user_full_name = "Test User"
|
|
|
|
@pytest.fixture(scope="module")
def setup_database():
    """Create the schema and seed the test user for this module's tests.

    Yields nothing; tests depend on it purely for its side effects. All
    tables are dropped again when the module's tests finish, and also when
    seeding itself fails.
    """
    # The database is a file on disk, so tables left behind by an interrupted
    # earlier run could still contain the seeded user. Start from a clean
    # schema so the insert below always targets empty tables.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    try:
        # Seed the user the authentication tests log in as. The session is
        # closed as soon as seeding is done so it does not stay open for the
        # lifetime of the module.
        db = TestingSessionLocal()
        try:
            db.add(
                User(
                    email=test_user_email,
                    hashed_password=get_password_hash(test_user_password),
                    full_name=test_user_full_name,
                    is_active=True,
                    is_superuser=False,
                )
            )
            db.commit()
        finally:
            db.close()

        yield
    finally:
        # Clean up, whether the tests ran or seeding failed part-way.
        Base.metadata.drop_all(bind=engine)
|
|
|
|
@pytest.fixture(scope="module")
def auth_token(setup_database):
    """Log in as the seeded test user and return the issued access token."""
    credentials = {"username": test_user_email, "password": test_user_password}
    response = client.post(f"{settings.API_V1_STR}/auth/token", data=credentials)
    assert response.status_code == 200

    payload = response.json()
    assert "access_token" in payload
    assert payload["token_type"] == "bearer"
    return payload["access_token"]
|
|
|
|
def test_root():
    """Check the status and metadata fields returned by the root endpoint."""
    response = client.get("/")
    assert response.status_code == 200

    body = response.json()
    expected_fields = {
        "status": "online",
        "version": settings.VERSION,
        "project": settings.PROJECT_NAME,
        "docs": "/docs",
    }
    # Compared field by field, in the order listed above.
    for field, expected in expected_fields.items():
        assert body[field] == expected
|
|
|
|
def test_auth_token(setup_database):
    """Valid credentials yield a 200 response carrying a bearer access token."""
    login_form = {"username": test_user_email, "password": test_user_password}
    response = client.post(f"{settings.API_V1_STR}/auth/token", data=login_form)
    assert response.status_code == 200

    payload = response.json()
    assert "access_token" in payload
    assert payload["token_type"] == "bearer"
|
|
|
|
def test_auth_token_invalid_credentials(setup_database):
    """A wrong password is rejected with 401 and the expected error detail."""
    login_form = {"username": test_user_email, "password": "wrong_password"}
    response = client.post(f"{settings.API_V1_STR}/auth/token", data=login_form)

    assert response.status_code == 401
    error = response.json()
    assert error["detail"] == "Incorrect email or password"
|
|
|
|
def test_register_user(setup_database):
    """Register a new user and check the fields echoed back by the API."""
    response = client.post(
        f"{settings.API_V1_STR}/auth/register",
        json={
            "email": "new_user@example.com",
            "password": "password123",
            "full_name": "New User",
            "is_active": True,
            "is_superuser": False,
        },
    )
    assert response.status_code == 200

    user_data = response.json()
    assert user_data["email"] == "new_user@example.com"
    assert user_data["full_name"] == "New User"
    # Identity checks: ``== True`` / ``== False`` would also accept 1 / 0,
    # whereas these flags are expected to be real JSON booleans.
    assert user_data["is_active"] is True
    assert user_data["is_superuser"] is False
|
|
|
|
def test_register_existing_user(setup_database):
    """Registering with the seeded user's email is rejected with a 400."""
    duplicate_registration = {
        "email": test_user_email,
        "password": "password123",
        "full_name": "Duplicate User",
        "is_active": True,
        "is_superuser": False,
    }
    response = client.post(
        f"{settings.API_V1_STR}/auth/register",
        json=duplicate_registration,
    )

    assert response.status_code == 400
    error = response.json()
    assert error["detail"] == "A user with this email already exists"
|
|
|
|
def test_process_query(auth_token):
    """Process a query and check that the original text is echoed back."""
    question = "What are the environmental impacts of electric vehicles?"
    response = client.post(
        f"{settings.API_V1_STR}/query/process",
        json={"query": question},
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert response.status_code == 200

    body = response.json()
    assert body["original_query"] == question
    assert "structured_query" in body
    assert body["structured_query"]["original_query"] == question
|
|
|
|
def test_classify_query(auth_token):
    """Classify a query and check the structured result has type and domain."""
    question = "What are the environmental impacts of electric vehicles?"
    response = client.post(
        f"{settings.API_V1_STR}/query/classify",
        json={"query": question},
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert response.status_code == 200

    body = response.json()
    assert body["original_query"] == question
    assert "structured_query" in body

    structured = body["structured_query"]
    assert structured["original_query"] == question
    assert "type" in structured
    assert "domain" in structured
|
|
|
|
def test_get_available_search_engines(auth_token):
    """The engines endpoint returns a non-empty list."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    response = client.get(f"{settings.API_V1_STR}/search/engines", headers=auth_headers)
    assert response.status_code == 200

    available = response.json()
    assert isinstance(available, list)
    assert len(available) > 0
|
|
|
|
def test_execute_search(auth_token):
    """Execute a search and check the shape of the response."""
    question = "What are the environmental impacts of electric vehicles?"
    search_request = {
        "structured_query": {
            "original_query": question,
            "enhanced_query": question,
            "type": "factual",
            "domain": "environmental",
        },
        "search_engines": ["google", "arxiv"],
        "num_results": 5,
        "timeout": 30,
    }

    response = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json=search_request,
        headers={"Authorization": f"Bearer {auth_token}"},
    )
    assert response.status_code == 200

    body = response.json()
    assert "search_id" in body
    assert body["query"] == question
    # The remaining fields are only checked for presence.
    for field in ("results", "total_results", "execution_time"):
        assert field in body
|
|
|
|
def test_get_search_history(auth_token):
    """The history endpoint returns a list of searches and an integer total."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    response = client.get(f"{settings.API_V1_STR}/search/history", headers=auth_headers)
    assert response.status_code == 200

    history = response.json()
    assert "searches" in history
    assert "total" in history
    assert isinstance(history["searches"], list)
    assert isinstance(history["total"], int)
|
|
|
|
def test_get_search_results(auth_token):
    """Execute a search, then fetch it again by its search_id."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    question = "What are the economic benefits of electric vehicles?"

    # Step 1: run a search so there is a search_id to look up.
    execute_response = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json={
            "structured_query": {
                "original_query": question,
                "enhanced_query": question,
                "type": "factual",
                "domain": "economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth_headers,
    )
    assert execute_response.status_code == 200
    search_id = execute_response.json()["search_id"]

    # Step 2: fetch the stored search by id.
    fetch_response = client.get(
        f"{settings.API_V1_STR}/search/{search_id}",
        headers=auth_headers,
    )
    assert fetch_response.status_code == 200

    body = fetch_response.json()
    assert body["search_id"] == search_id
    assert body["query"] == question
    assert "results" in body
    assert "total_results" in body
|
|
|
|
def test_generate_report(auth_token):
    """Run a search, generate a report from it, then query the report's progress."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    question = "What are the environmental and economic impacts of electric vehicles?"

    # Step 1: execute a search so there is a search_id to report on.
    search_response = client.post(
        f"{settings.API_V1_STR}/search/execute",
        json={
            "structured_query": {
                "original_query": question,
                "enhanced_query": question,
                "type": "comparative",
                "domain": "environmental,economic",
            },
            "search_engines": ["google", "arxiv"],
            "num_results": 5,
            "timeout": 30,
        },
        headers=auth_headers,
    )
    assert search_response.status_code == 200
    search_id = search_response.json()["search_id"]

    # Step 2: request a report for that search.
    report_response = client.post(
        f"{settings.API_V1_STR}/report/generate",
        json={
            "search_id": search_id,
            "query": question,
            "detail_level": "standard",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=auth_headers,
    )
    assert report_response.status_code == 200

    report = report_response.json()
    assert "id" in report
    assert report["title"].startswith("Report: What are the environmental and economic impacts")
    assert report["detail_level"] == "standard"
    assert report["query_type"] == "comparative"
    assert report["model_used"] == "llama-3.1-8b-instant"

    # Step 3: ask for the progress of the report just created.
    report_id = report["id"]
    progress_response = client.get(
        f"{settings.API_V1_STR}/report/{report_id}/progress",
        headers=auth_headers,
    )
    assert progress_response.status_code == 200

    progress = progress_response.json()
    assert progress["report_id"] == report_id
    assert "progress" in progress
    assert "status" in progress
|
|
|
|
def test_get_report_list(auth_token):
    """The report list endpoint returns a list of reports and an integer total."""
    auth_headers = {"Authorization": f"Bearer {auth_token}"}
    response = client.get(f"{settings.API_V1_STR}/report/list", headers=auth_headers)
    assert response.status_code == 200

    listing = response.json()
    assert "reports" in listing
    assert "total" in listing
    assert isinstance(listing["reports"], list)
    assert isinstance(listing["total"], int)
|
|
|
|
def test_get_report(auth_token):
    """Fetch the first listed report by id and check its fields are present.

    Requires at least one report to exist already; this test does not
    create one itself.
    """
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Pick a report id from the listing.
    list_response = client.get(f"{settings.API_V1_STR}/report/list", headers=auth_headers)
    assert list_response.status_code == 200
    listing = list_response.json()
    assert len(listing["reports"]) > 0
    report_id = listing["reports"][0]["id"]

    # Fetch that report on its own.
    report_response = client.get(
        f"{settings.API_V1_STR}/report/{report_id}",
        headers=auth_headers,
    )
    assert report_response.status_code == 200

    report = report_response.json()
    assert report["id"] == report_id
    for field in ("title", "content", "detail_level", "query_type", "model_used"):
        assert field in report
|
|
|
|
def test_download_report(auth_token):
    """Download the first listed report as markdown and as HTML.

    Requires at least one report to exist already; this test does not
    create one itself.
    """
    headers = {"Authorization": f"Bearer {auth_token}"}

    # First, get the list of reports to get a report_id.
    response = client.get(f"{settings.API_V1_STR}/report/list", headers=headers)
    assert response.status_code == 200
    list_data = response.json()
    assert len(list_data["reports"]) > 0
    report_id = list_data["reports"][0]["id"]

    # The same checks apply to every format, so run them in one loop instead
    # of a copy-pasted block per format. The messages name the failing format.
    for fmt in ("markdown", "html"):
        response = client.get(
            f"{settings.API_V1_STR}/report/{report_id}/download?format={fmt}",
            headers=headers,
        )
        assert response.status_code == 200, f"download failed for format {fmt!r}"
        assert response.headers["content-type"] == "application/octet-stream", (
            f"unexpected content-type for format {fmt!r}"
        )
        assert response.headers["content-disposition"] == (
            f'filename="report_{report_id}.{fmt}"'
        ), f"unexpected content-disposition for format {fmt!r}"
|
|
|
|
def test_delete_report(auth_token):
    """Delete the first listed report and confirm it can no longer be fetched.

    Requires at least one report to exist already; this test does not
    create one itself.
    """
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Pick a report id from the listing.
    list_response = client.get(f"{settings.API_V1_STR}/report/list", headers=auth_headers)
    assert list_response.status_code == 200
    listing = list_response.json()
    assert len(listing["reports"]) > 0
    report_id = listing["reports"][0]["id"]

    report_url = f"{settings.API_V1_STR}/report/{report_id}"

    # Delete it.
    delete_response = client.delete(report_url, headers=auth_headers)
    assert delete_response.status_code == 204

    # A follow-up GET on the same id must now report "not found".
    lookup_response = client.get(report_url, headers=auth_headers)
    assert lookup_response.status_code == 404
|
|
|
|
def test_delete_search(auth_token):
    """Delete the first search in the history and confirm it is gone.

    Requires at least one search to exist in the history already; this test
    does not execute one itself.
    """
    auth_headers = {"Authorization": f"Bearer {auth_token}"}

    # Pick a search id from the history.
    history_response = client.get(
        f"{settings.API_V1_STR}/search/history",
        headers=auth_headers,
    )
    assert history_response.status_code == 200
    history = history_response.json()
    assert len(history["searches"]) > 0
    search_id = history["searches"][0]["id"]

    search_url = f"{settings.API_V1_STR}/search/{search_id}"

    # Delete it.
    delete_response = client.delete(search_url, headers=auth_headers)
    assert delete_response.status_code == 204

    # A follow-up GET on the same id must now report "not found".
    lookup_response = client.get(search_url, headers=auth_headers)
    assert lookup_response.status_code == 404
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's return code as the process exit status so that
    # running this file directly fails (non-zero) when any test fails.
    sys.exit(pytest.main(["-xvs", __file__]))
|