#!/usr/bin/env python3
"""
Test script for Phase 1 implementation - Abstract base class and data models.

Run directly (``python <this file>``); exits with status 0 when every check
passes and 1 on the first failure (the traceback is printed).
"""

import sys
from pathlib import Path

# Add project root to path. Insert at the front (rather than append) so the
# local checkout's `backend` package takes precedence over any same-named
# package already importable from site-packages. `.resolve()` makes the root
# absolute even when the script is launched via a relative path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

# These imports must come after the sys.path tweak above.
from backend.app.models.tts_models import (  # noqa: E402
    TTSParameters,
    SpeakerConfig,
    OutputConfig,
    TTSRequest,
    TTSResponse,
)
from backend.app.services.base_tts_service import BaseTTSService, TTSError  # noqa: E402
from backend.app import config  # noqa: E402


def test_data_models():
    """Test TTS data models.

    Builds each model with representative values and checks the fields it
    was given. Also checks that a higgs ``SpeakerConfig`` without
    ``reference_text`` is rejected with ``ValueError``, either by its
    constructor or by ``validate()``, and that one with ``reference_text``
    validates.

    Raises:
        AssertionError: if any model does not behave as expected.
    """
    print("Testing TTS data models...")

    # Test TTSParameters
    params = TTSParameters(
        temperature=0.8,
        backend_params={"max_new_tokens": 512, "top_p": 0.9},
    )
    assert params.temperature == 0.8
    assert params.backend_params["max_new_tokens"] == 512
    print("✓ TTSParameters working correctly")

    # Test SpeakerConfig for chatterbox backend
    speaker_config_chatterbox = SpeakerConfig(
        id="test-speaker-1",
        name="Test Speaker",
        sample_path="/tmp/test_sample.wav",
        tts_backend="chatterbox",
    )
    print("✓ SpeakerConfig for chatterbox backend working")

    # Test SpeakerConfig validation for higgs backend (should raise error
    # without reference_text).
    try:
        speaker_config_higgs_invalid = SpeakerConfig(
            id="test-speaker-2",
            name="Invalid Higgs Speaker",
            sample_path="/tmp/test_sample.wav",
            tts_backend="higgs",
        )
        speaker_config_higgs_invalid.validate()
    except ValueError:
        print("✓ SpeakerConfig validation correctly catches missing reference_text for higgs")
    else:
        # Explicit raise rather than `assert False` so this negative check is
        # not stripped when Python runs with -O.
        raise AssertionError("Should have raised ValueError for missing reference_text")

    # Test valid SpeakerConfig for higgs backend
    speaker_config_higgs_valid = SpeakerConfig(
        id="test-speaker-3",
        name="Valid Higgs Speaker",
        sample_path="/tmp/test_sample.wav",
        reference_text="Hello, this is a test.",
        tts_backend="higgs",
    )
    speaker_config_higgs_valid.validate()  # Should not raise
    print("✓ SpeakerConfig for higgs backend with reference_text working")

    # Test OutputConfig
    output_config = OutputConfig(
        filename_base="test_output",
        output_dir=Path("/tmp"),
        format="wav",
    )
    assert output_config.filename_base == "test_output"
    print("✓ OutputConfig working correctly")

    # Test TTSRequest
    request = TTSRequest(
        text="Hello world, this is a test.",
        speaker_config=speaker_config_chatterbox,
        parameters=params,
        output_config=output_config,
    )
    assert request.text == "Hello world, this is a test."
    assert request.speaker_config.name == "Test Speaker"
    print("✓ TTSRequest working correctly")

    # Test TTSResponse
    response = TTSResponse(
        output_path=Path("/tmp/output.wav"),
        generated_text="Hello world, this is a test.",
        audio_duration=3.5,
        sampling_rate=22050,
        backend_used="chatterbox",
    )
    assert response.audio_duration == 3.5
    assert response.backend_used == "chatterbox"
    print("✓ TTSResponse working correctly")


def test_base_service():
    """Test the abstract ``BaseTTSService`` via a minimal concrete subclass.

    Checks the following:
    - ``device="auto"`` resolves to one of "cuda", "mps" or "cpu".
    - ``backend_name`` for ``MockTTSService`` is "mock".
    - A freshly constructed service reports ``is_loaded()`` as false.
    """
    print("\nTesting abstract base service...")

    # Create a mock implementation that satisfies the abstract interface.
    class MockTTSService(BaseTTSService):
        async def load_model(self):
            self.model = "mock_model_loaded"

        async def unload_model(self):
            self.model = None

        async def generate_speech(self, request):
            return TTSResponse(
                output_path=Path("/tmp/mock_output.wav"),
                backend_used=self.backend_name,
            )

        def validate_speaker_config(self, config):
            return True

    # Test device resolution
    mock_service = MockTTSService(device="auto")
    assert mock_service.device in ["cuda", "mps", "cpu"]
    print(f"✓ Device auto-resolution: {mock_service.device}")

    # Test backend name extraction (expected to be derived from the class
    # name "MockTTSService" -> "mock"; the derivation lives in the base class).
    assert mock_service.backend_name == "mock"
    print("✓ Backend name extraction working")

    # Test model loading state: nothing has been loaded yet.
    assert not mock_service.is_loaded()
    print("✓ Initial model state check")


def test_configuration():
    """Test that the config module exposes the settings Phase 1 relies on."""
    print("\nTesting configuration...")

    assert hasattr(config, 'HIGGS_MODEL_PATH')
    assert hasattr(config, 'HIGGS_AUDIO_TOKENIZER_PATH')
    assert hasattr(config, 'DEFAULT_TTS_BACKEND')
    assert hasattr(config, 'TTS_BACKEND_DEFAULTS')

    print(f"✓ Default TTS backend: {config.DEFAULT_TTS_BACKEND}")
    print(f"✓ Higgs model path: {config.HIGGS_MODEL_PATH}")

    # Test backend defaults
    assert "chatterbox" in config.TTS_BACKEND_DEFAULTS
    assert "higgs" in config.TTS_BACKEND_DEFAULTS
    assert "temperature" in config.TTS_BACKEND_DEFAULTS["chatterbox"]
    assert "max_new_tokens" in config.TTS_BACKEND_DEFAULTS["higgs"]
    print("✓ TTS backend defaults configured correctly")


def test_error_handling():
    """Test TTS error classes.

    Checks that ``TTSError`` stores ``backend`` and ``error_code``. Also
    checks that ``BackendSpecificError`` can be caught as ``TTSError``.
    """
    print("\nTesting error handling...")

    # Test TTSError
    try:
        raise TTSError("Test error", "test_backend", "ERROR_001")
    except TTSError as e:
        assert e.backend == "test_backend"
        assert e.error_code == "ERROR_001"
        print("✓ TTSError working correctly")

    # Test BackendSpecificError inheritance
    from backend.app.services.base_tts_service import BackendSpecificError

    try:
        raise BackendSpecificError("Backend specific error", "higgs")
    except TTSError as e:  # Should catch as base class
        assert e.backend == "higgs"
        print("✓ BackendSpecificError inheritance working correctly")


def main():
    """Run all tests.

    Returns:
        0 if every test passed, 1 on the first failure.
    """
    print("=== Phase 1 Implementation Tests ===\n")

    try:
        test_data_models()
        test_base_service()
        test_configuration()
        test_error_handling()

        print("\n=== All Phase 1 tests passed! ✓ ===")
        print("\nPhase 1 components ready:")
        print("- TTS data models (TTSRequest, TTSResponse, etc.)")
        print("- Abstract BaseTTSService class")
        print("- Configuration system with Higgs support")
        print("- Error handling framework")
        print("\nReady to proceed to Phase 2: Service Implementation")

    except Exception as e:
        # Include the exception type: a bare `assert` failure has an empty
        # message, which would otherwise print as "Test failed: ".
        print(f"\n❌ Test failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided `exit`, which is absent under
    # `python -S` and in some embedded/frozen interpreters.
    sys.exit(main())