#!/usr/bin/env python3
"""
Test script for Phase 2 implementation - Service implementations and factory.

Exercises the Chatterbox and Higgs TTS service classes, the TTSServiceFactory
(backend lookup, singleton caching, error handling, custom registration) and
the parameter / request models. Model loading is skipped everywhere except
test_error_handling, which attempts to load the Higgs model in order to check
how missing dependencies are reported.

Run directly; the process exits with status 0 when every check passes and 1
otherwise.
"""

import asyncio
import sys
import tempfile
import traceback
from pathlib import Path

# Make the project root importable so the ``backend`` package resolves when
# this script is run directly from its own directory.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from backend.app.models.tts_models import (  # noqa: E402
    TTSParameters,
    SpeakerConfig,
    OutputConfig,
    TTSRequest,
    TTSResponse,
)
from backend.app.services.chatterbox_tts_service import ChatterboxTTSService  # noqa: E402
from backend.app.services.higgs_tts_service import HiggsTTSService  # noqa: E402
from backend.app.services.tts_factory import (  # noqa: E402
    TTSServiceFactory,
    get_tts_service,
    list_available_backends,
)
from backend.app.services.base_tts_service import TTSError  # noqa: E402
from backend.app import config  # noqa: E402


def test_chatterbox_service():
    """Test ChatterboxTTSService creation and speaker validation."""
    print("Testing ChatterboxTTSService...")

    # Test service creation
    service = ChatterboxTTSService(device="auto")
    assert service.backend_name == "chatterbox"
    assert service.device in ["cuda", "mps", "cpu"]
    assert not service.is_loaded()
    print(f"✓ ChatterboxTTSService created with device: {service.device}")

    # Test speaker validation - valid chatterbox speaker
    valid_speaker = SpeakerConfig(
        id="test-chatterbox",
        name="Test Chatterbox Speaker",
        sample_path="speaker_samples/test.wav",  # Relative path
        tts_backend="chatterbox",
    )

    # Validation is expected to fail because the sample file does not exist;
    # the point is only that it returns rather than raising.
    result = service.validate_speaker_config(valid_speaker)
    print(f"✓ Speaker validation (expected to fail due to missing file): {result}")

    # Test speaker validation - wrong backend
    wrong_backend_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
    )

    assert not service.validate_speaker_config(wrong_backend_speaker)
    print("✓ Chatterbox service correctly rejects Higgs speaker")


def test_higgs_service():
    """Test HiggsTTSService creation, model info and speaker validation."""
    print("\nTesting HiggsTTSService...")

    # Test service creation
    service = HiggsTTSService(device="auto")
    assert service.backend_name == "higgs"
    assert service.device in ["cuda", "mps", "cpu"]
    assert not service.is_loaded()
    print(f"✓ HiggsTTSService created with device: {service.device}")

    # Test model info
    info = service.get_model_info()
    assert info["backend"] == "higgs"
    assert "dependencies_available" in info
    print(f"✓ Higgs model info: dependencies_available={info['dependencies_available']}")

    # Test speaker validation - valid higgs speaker
    valid_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="speaker_samples/test.wav",
        reference_text="Hello, this is a test reference.",
        tts_backend="higgs",
    )

    # Validation is expected to fail because the sample file does not exist.
    result = service.validate_speaker_config(valid_speaker)
    print(f"✓ Higgs speaker validation (expected to fail due to missing file): {result}")

    # Test speaker validation - missing reference text
    invalid_speaker = SpeakerConfig(
        id="test-invalid",
        name="Invalid Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
        # Missing reference_text
    )

    assert not service.validate_speaker_config(invalid_speaker)
    print("✓ Higgs service correctly rejects speaker without reference_text")


def test_factory_pattern():
    """Test TTSServiceFactory backend listing, creation, caching and errors."""
    print("\nTesting TTSServiceFactory...")

    # Test available backends
    backends = TTSServiceFactory.get_available_backends()
    assert "chatterbox" in backends
    assert "higgs" in backends
    print(f"✓ Available backends: {backends}")

    # Test service creation
    chatterbox_service = TTSServiceFactory.create_service("chatterbox")
    assert isinstance(chatterbox_service, ChatterboxTTSService)
    assert chatterbox_service.backend_name == "chatterbox"
    print("✓ Factory creates ChatterboxTTSService correctly")

    higgs_service = TTSServiceFactory.create_service("higgs")
    assert isinstance(higgs_service, HiggsTTSService)
    assert higgs_service.backend_name == "higgs"
    print("✓ Factory creates HiggsTTSService correctly")

    # Test singleton behavior: a second request returns the cached instance.
    chatterbox_service2 = TTSServiceFactory.create_service("chatterbox")
    assert chatterbox_service is chatterbox_service2
    print("✓ Factory singleton behavior working")

    # Test unknown backend. Only the create call sits in the try body, and
    # the "no exception" path raises explicitly so it is not stripped by -O.
    try:
        TTSServiceFactory.create_service("unknown_backend")
    except TTSError as e:
        assert e.backend == "unknown_backend"
        print("✓ Factory correctly handles unknown backend")
    else:
        raise AssertionError("Should have raised TTSError")

    # Test backend info
    info = TTSServiceFactory.get_backend_info()
    assert "chatterbox" in info
    assert "higgs" in info
    print("✓ Backend info retrieval working")

    # Test service stats
    stats = TTSServiceFactory.get_service_stats()
    assert stats["total_backends"] >= 2
    assert "chatterbox" in stats["backends"]
    print(f"✓ Service stats: {stats['total_backends']} backends, {stats['loaded_instances']} instances")


def test_utility_functions():
    """Test module-level factory helper functions."""
    print("\nTesting utility functions...")

    # Test list_available_backends
    backends = list_available_backends()
    assert isinstance(backends, list)
    assert "chatterbox" in backends
    print(f"✓ list_available_backends: {backends}")


async def test_async_operations():
    """Test the async get_tts_service helper (no model loading)."""
    print("\nTesting async operations...")

    # Test get_tts_service utility
    service = await get_tts_service("chatterbox")
    assert isinstance(service, ChatterboxTTSService)
    print("✓ get_tts_service utility working")

    # Service lifecycle (load/unload) is intentionally not exercised here to
    # avoid loading heavy models.
    print("✓ Async service creation working (model loading skipped for test)")


def test_parameter_handling():
    """Test that per-backend defaults from config map into TTSParameters."""
    print("\nTesting parameter handling...")

    # Test chatterbox parameters
    chatterbox_params = TTSParameters(
        temperature=0.7,
        backend_params=config.TTS_BACKEND_DEFAULTS["chatterbox"],
    )
    assert chatterbox_params.backend_params["exaggeration"] == 0.5
    assert chatterbox_params.backend_params["cfg_weight"] == 0.5
    print("✓ Chatterbox parameter defaults loaded")

    # Test higgs parameters
    higgs_params = TTSParameters(
        temperature=0.9,
        backend_params=config.TTS_BACKEND_DEFAULTS["higgs"],
    )
    assert higgs_params.backend_params["max_new_tokens"] == 1024
    assert higgs_params.backend_params["top_p"] == 0.95
    print("✓ Higgs parameter defaults loaded")


def test_request_response_flow():
    """Test building a complete TTSRequest (without actual generation)."""
    print("\nTesting request/response flow...")

    # Create test speaker config
    speaker = SpeakerConfig(
        id="test-speaker",
        name="Test Speaker",
        sample_path="speaker_samples/test.wav",
        tts_backend="chatterbox",
    )

    # Create test parameters
    params = TTSParameters(
        temperature=0.8,
        backend_params=config.TTS_BACKEND_DEFAULTS["chatterbox"],
    )

    # Create test output config
    output = OutputConfig(
        filename_base="test_generation",
        output_dir=Path(tempfile.gettempdir()),
        format="wav",
    )

    # Create test request
    request = TTSRequest(
        text="Hello, this is a test generation.",
        speaker_config=speaker,
        parameters=params,
        output_config=output,
    )

    assert request.text == "Hello, this is a test generation."
    assert request.speaker_config.tts_backend == "chatterbox"
    assert request.parameters.backend_params["exaggeration"] == 0.5
    print("✓ TTS request creation working correctly")


async def test_error_handling():
    """Test how the Higgs service reports load failures.

    Attempts a real model load. A TTSError is treated as handled, whatever
    its error code. On success, the model is unloaded again so the factory's
    cached instance does not keep it resident.
    """
    print("\nTesting error handling...")

    service = TTSServiceFactory.create_service("higgs")

    # Test handling of missing dependencies (if Higgs not installed)
    try:
        await service.load_model()
    except TTSError as e:
        if e.error_code == "MISSING_DEPENDENCIES":
            print("✓ Higgs service correctly handles missing dependencies")
        else:
            print(f"✓ Higgs service error handling: {e}")
    else:
        print("✓ Higgs model loading (dependencies available)")
        # Release the model; nothing later in this script needs it.
        await service.unload_model()


def test_service_registration():
    """Test registering and creating a custom backend via TTSServiceFactory."""
    print("\nTesting service registration...")

    from backend.app.services.base_tts_service import BaseTTSService

    class CustomTTSService(BaseTTSService):
        """Minimal no-op backend used only to exercise factory registration."""

        async def load_model(self):
            pass

        async def unload_model(self):
            pass

        async def generate_speech(self, request: TTSRequest) -> TTSResponse:
            return TTSResponse(
                output_path=Path(tempfile.gettempdir()) / "custom.wav",
                backend_used="custom",
            )

        def validate_speaker_config(self, speaker_config):
            return True

    # Register custom service.
    # NOTE(review): this registration is left in the factory's registry after
    # the test; nothing visible here removes it.
    TTSServiceFactory.register_service("custom", CustomTTSService)

    # Test creation
    custom_service = TTSServiceFactory.create_service("custom")
    assert isinstance(custom_service, CustomTTSService)
    assert custom_service.backend_name == "custom"
    print("✓ Custom service registration working")


async def main():
    """Run all Phase 2 tests in order.

    Returns:
        0 if every test passed, 1 if any raised (the traceback is printed).
    """
    print("=== Phase 2 Implementation Tests ===\n")

    try:
        test_chatterbox_service()
        test_higgs_service()
        test_factory_pattern()
        test_utility_functions()
        await test_async_operations()
        test_parameter_handling()
        test_request_response_flow()
        await test_error_handling()
        test_service_registration()

        print("\n=== All Phase 2 tests passed! ✓ ===")
        print("\nPhase 2 components ready:")
        print("- ChatterboxTTSService (refactored with abstract base)")
        print("- HiggsTTSService (with voice cloning support)")
        print("- TTSServiceFactory (singleton pattern with lifecycle management)")
        print("- Error handling for missing dependencies")
        print("- Parameter mapping for different backends")
        print("- Service registration for extensibility")
        print("\nReady to proceed to Phase 3: Enhanced Data Models and Validation")

        return 0

    except Exception as e:
        # Top-level boundary: report any failure (including AssertionError)
        # and convert it into a non-zero exit status.
        print(f"\n❌ Test failed: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is intended for
    # the interactive shell and is unavailable under `python -S`.
    sys.exit(asyncio.run(main()))