""" Test script for the sim-search API. This script tests the core functionality of the API, including authentication, query processing, search execution, and report generation. """ import os import sys import asyncio import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool # Add the project root directory to the Python path sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from app.main import app from app.db.session import Base from app.db.models import User from app.core.security import get_password_hash from app.core.config import settings from app.api.dependencies import get_db # Create a test database TEST_SQLALCHEMY_DATABASE_URI = "sqlite:///./test.db" engine = create_engine( TEST_SQLALCHEMY_DATABASE_URI, connect_args={"check_same_thread": False}, poolclass=StaticPool, ) TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) # Override the get_db dependency def override_get_db(): try: db = TestingSessionLocal() yield db finally: db.close() app.dependency_overrides[get_db] = override_get_db # Create a test client client = TestClient(app) # Test user credentials test_user_email = "test@example.com" test_user_password = "password123" test_user_full_name = "Test User" @pytest.fixture(scope="module") def setup_database(): """Set up the test database.""" # Clean up any existing database if os.path.exists("./test.db"): os.remove("./test.db") # Create tables Base.metadata.create_all(bind=engine) # Create a test user db = TestingSessionLocal() user = User( email=test_user_email, hashed_password=get_password_hash(test_user_password), full_name=test_user_full_name, is_active=True, is_superuser=False, ) db.add(user) try: db.commit() db.refresh(user) except Exception as e: db.rollback() print(f"Error creating test user: {e}") finally: db.close() yield # Clean up Base.metadata.drop_all(bind=engine) 
if os.path.exists("./test.db"): os.remove("./test.db") @pytest.fixture(scope="module") def auth_token(setup_database): """Get an authentication token for the test user.""" response = client.post( f"{settings.API_V1_STR}/auth/token", data={"username": test_user_email, "password": test_user_password}, ) assert response.status_code == 200 token_data = response.json() assert "access_token" in token_data assert token_data["token_type"] == "bearer" return token_data["access_token"] def test_root(): """Test the root endpoint.""" response = client.get("/") assert response.status_code == 200 data = response.json() assert data["status"] == "online" assert data["version"] == settings.VERSION assert data["project"] == settings.PROJECT_NAME assert data["docs"] == "/docs" def test_auth_token(setup_database): """Test getting an authentication token.""" response = client.post( f"{settings.API_V1_STR}/auth/token", data={"username": test_user_email, "password": test_user_password}, ) assert response.status_code == 200 token_data = response.json() assert "access_token" in token_data assert token_data["token_type"] == "bearer" def test_auth_token_invalid_credentials(setup_database): """Test getting an authentication token with invalid credentials.""" response = client.post( f"{settings.API_V1_STR}/auth/token", data={"username": test_user_email, "password": "wrong_password"}, ) assert response.status_code == 401 assert response.json()["detail"] == "Incorrect email or password" def test_register_user(setup_database): """Test registering a new user.""" response = client.post( f"{settings.API_V1_STR}/auth/register", json={ "email": "new_user@example.com", "password": "password123", "full_name": "New User", "is_active": True, "is_superuser": False, }, ) assert response.status_code == 200 user_data = response.json() assert user_data["email"] == "new_user@example.com" assert user_data["full_name"] == "New User" assert user_data["is_active"] == True assert user_data["is_superuser"] == False 
def _api(path):
    """Return *path* prefixed with the versioned API root from settings."""
    return f"{settings.API_V1_STR}{path}"


def _bearer(token):
    """Build the Authorization header dict for the given access token."""
    return {"Authorization": f"Bearer {token}"}


def _run_search(token, question, query_type, domain, engines, num_results, timeout):
    """Execute a search through the API and return the decoded JSON body.

    The same text is sent as both the original and the enhanced query.
    Asserts that the endpoint answers with HTTP 200.
    """
    payload = {
        "structured_query": {
            "original_query": question,
            "enhanced_query": question,
            "type": query_type,
            "domain": domain,
        },
        "search_engines": engines,
        "num_results": num_results,
        "timeout": timeout,
    }
    reply = client.post(_api("/search/execute"), json=payload, headers=_bearer(token))
    assert reply.status_code == 200
    return reply.json()


def test_register_existing_user(setup_database):
    """Registering with an already-used email must be rejected."""
    payload = {
        "email": test_user_email,
        "password": "password123",
        "full_name": "Duplicate User",
        "is_active": True,
        "is_superuser": False,
    }
    reply = client.post(_api("/auth/register"), json=payload)
    assert reply.status_code == 400
    assert reply.json()["detail"] == "A user with this email already exists"


def test_process_query(auth_token):
    """The process endpoint echoes the query and returns a structured query."""
    question = "What are the environmental impacts of electric vehicles?"
    reply = client.post(
        _api("/query/process"),
        json={"query": question},
        headers=_bearer(auth_token),
    )
    assert reply.status_code == 200
    body = reply.json()
    assert body["original_query"] == question
    assert "structured_query" in body
    assert body["structured_query"]["original_query"] == question


def test_classify_query(auth_token):
    """The classify endpoint adds type and domain to the structured query."""
    question = "What are the environmental impacts of electric vehicles?"
    reply = client.post(
        _api("/query/classify"),
        json={"query": question},
        headers=_bearer(auth_token),
    )
    assert reply.status_code == 200
    body = reply.json()
    assert body["original_query"] == question
    assert "structured_query" in body
    structured = body["structured_query"]
    assert structured["original_query"] == question
    assert "type" in structured
    assert "domain" in structured


def test_get_available_search_engines(auth_token):
    """The engines endpoint returns a non-empty list."""
    reply = client.get(_api("/search/engines"), headers=_bearer(auth_token))
    assert reply.status_code == 200
    available = reply.json()
    assert isinstance(available, list)
    assert len(available) > 0


def test_execute_search(auth_token):
    """Executing a search returns an id, the query, and result metadata."""
    question = "What are the environmental impacts of electric vehicles?"
    body = _run_search(
        auth_token, question, "factual", "environmental", ["google", "arxiv"], 5, 30
    )
    assert "search_id" in body
    assert body["query"] == question
    assert "results" in body
    assert "total_results" in body
    assert "execution_time" in body


def test_get_search_history(auth_token):
    """The history endpoint returns a list of searches and an integer total."""
    reply = client.get(_api("/search/history"), headers=_bearer(auth_token))
    assert reply.status_code == 200
    body = reply.json()
    assert "searches" in body
    assert "total" in body
    assert isinstance(body["searches"], list)
    assert isinstance(body["total"], int)


def test_get_search_results(auth_token):
    """A previously executed search can be fetched again by its id."""
    question = "What are the economic benefits of electric vehicles?"

    # First, execute a search to get a search_id
    executed = _run_search(
        auth_token, question, "factual", "economic", ["google", "arxiv"], 5, 30
    )
    search_id = executed["search_id"]

    # Now get the search results
    reply = client.get(_api(f"/search/{search_id}"), headers=_bearer(auth_token))
    assert reply.status_code == 200
    body = reply.json()
    assert body["search_id"] == search_id
    assert body["query"] == question
    assert "results" in body
    assert "total_results" in body


def test_generate_report(auth_token):
    """A report can be generated from a search and its progress queried."""
    question = "What are the environmental and economic impacts of electric vehicles?"
    headers = _bearer(auth_token)

    # First, execute a search to get a search_id
    executed = _run_search(
        auth_token,
        question,
        "comparative",
        "environmental,economic",
        ["google", "arxiv"],
        5,
        30,
    )
    search_id = executed["search_id"]

    # Now generate a report
    reply = client.post(
        _api("/report/generate"),
        json={
            "search_id": search_id,
            "query": question,
            "detail_level": "standard",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=headers,
    )
    assert reply.status_code == 200
    report = reply.json()
    assert "id" in report
    assert report["title"].startswith("Report: What are the environmental and economic impacts")
    assert report["detail_level"] == "standard"
    assert report["query_type"] == "comparative"
    assert report["model_used"] == "llama-3.1-8b-instant"

    # Get the report progress
    report_id = report["id"]
    reply = client.get(_api(f"/report/{report_id}/progress"), headers=headers)
    assert reply.status_code == 200
    progress = reply.json()
    assert progress["report_id"] == report_id
    assert "progress" in progress
    assert "status" in progress


def test_get_report_list(auth_token):
    """The report list endpoint returns a list of reports and an integer total."""
    reply = client.get(_api("/report/list"), headers=_bearer(auth_token))
    assert reply.status_code == 200
    body = reply.json()
    assert "reports" in body
    assert "total" in body
    assert isinstance(body["reports"], list)
    assert isinstance(body["total"], int)


def test_get_report(auth_token):
    """The first report in the list can be fetched individually."""
    headers = _bearer(auth_token)

    # First, get the list of reports to get a report_id
    reply = client.get(_api("/report/list"), headers=headers)
    assert reply.status_code == 200
    listing = reply.json()
    assert len(listing["reports"]) > 0
    report_id = listing["reports"][0]["id"]

    # Now get the specific report
    reply = client.get(_api(f"/report/{report_id}"), headers=headers)
    assert reply.status_code == 200
    report = reply.json()
    assert report["id"] == report_id
    for field in ("title", "content", "detail_level", "query_type", "model_used"):
        assert field in report


def test_download_report(auth_token):
    """A generated report can be downloaded as markdown and as HTML."""
    question = "What are the environmental impacts of electric vehicles?"
    headers = _bearer(auth_token)

    # First, execute a search to get a search_id
    executed = _run_search(
        auth_token, question, "comparative", "environmental,economic", ["serper"], 2, 10
    )
    search_id = executed["search_id"]

    # Now generate a report
    reply = client.post(
        _api("/report/generate"),
        json={
            "search_id": search_id,
            "query": question,
            "detail_level": "brief",
            "query_type": "comparative",
            "model": "llama-3.1-8b-instant",
        },
        headers=headers,
    )
    assert reply.status_code == 200
    report_id = reply.json()["id"]

    # Download the report in markdown format first, then in HTML format
    for fmt in ("markdown", "html"):
        reply = client.get(
            _api(f"/report/{report_id}/download?format={fmt}"),
            headers=headers,
        )
        assert reply.status_code == 200
        assert reply.headers["content-type"] == "application/octet-stream"
        assert reply.headers["content-disposition"] == f'attachment; filename="report_{report_id}.{fmt}"'


def test_delete_report(auth_token):
    """Deleting the first listed report makes it unavailable afterwards."""
    headers = _bearer(auth_token)

    # First, get the list of reports to get a report_id
    reply = client.get(_api("/report/list"), headers=headers)
    assert reply.status_code == 200
    listing = reply.json()
    assert len(listing["reports"]) > 0
    report_id = listing["reports"][0]["id"]
    report_url = _api(f"/report/{report_id}")

    # Now delete the report
    reply = client.delete(report_url, headers=headers)
    assert reply.status_code == 204

    # Verify that the report is deleted
    reply = client.get(report_url, headers=headers)
    assert reply.status_code == 404


def test_delete_search(auth_token):
    """Deleting the first search in the history makes it unavailable afterwards."""
    headers = _bearer(auth_token)

    # First, get the list of searches to get a search_id
    reply = client.get(_api("/search/history"), headers=headers)
    assert reply.status_code == 200
    listing = reply.json()
    assert len(listing["searches"]) > 0
    search_id = listing["searches"][0]["id"]
    search_url = _api(f"/search/{search_id}")

    # Now delete the search
    reply = client.delete(search_url, headers=headers)
    assert reply.status_code == 204

    # Verify that the search is deleted
    reply = client.get(search_url, headers=headers)
    assert reply.status_code == 404


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])