296 lines
11 KiB
Python
296 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for Phase 2 implementation - Service implementations and factory
|
|
"""
|
|
import sys
|
|
import asyncio
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# Make the directory one level above this script's folder importable, so the
# ``backend`` package resolves when the script is executed directly.
project_root = Path(__file__).parent.parent
sys.path.append(f"{project_root}")

from backend.app.models.tts_models import (
|
|
TTSParameters, SpeakerConfig, OutputConfig, TTSRequest, TTSResponse
|
|
)
|
|
from backend.app.services.chatterbox_tts_service import ChatterboxTTSService
|
|
from backend.app.services.higgs_tts_service import HiggsTTSService
|
|
from backend.app.services.tts_factory import TTSServiceFactory, get_tts_service, list_available_backends
|
|
from backend.app.services.base_tts_service import TTSError
|
|
from backend.app import config
|
|
|
|
def test_chatterbox_service():
    """Exercise ChatterboxTTSService construction and speaker validation.

    Checks the backend name, the resolved device, the initial unloaded
    state, and that a speaker configured for another backend is rejected.
    """
    print("Testing ChatterboxTTSService...")

    # Construction with automatic device selection; nothing is loaded yet.
    tts = ChatterboxTTSService(device="auto")
    assert tts.backend_name == "chatterbox"
    assert tts.device in ("cuda", "mps", "cpu")
    assert not tts.is_loaded()
    print(f"✓ ChatterboxTTSService created with device: {tts.device}")

    # A speaker targeting this backend. The sample file does not exist, so
    # the outcome is only reported, not asserted; the call must not raise.
    own_speaker = SpeakerConfig(
        id="test-chatterbox",
        name="Test Chatterbox Speaker",
        sample_path="speaker_samples/test.wav",  # relative path
        tts_backend="chatterbox",
    )
    outcome = tts.validate_speaker_config(own_speaker)
    print(f"✓ Speaker validation (expected to fail due to missing file): {outcome}")

    # A speaker bound to the Higgs backend must be refused by Chatterbox.
    foreign_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
    )
    assert not tts.validate_speaker_config(foreign_speaker)
    print("✓ Chatterbox service correctly rejects Higgs speaker")


def test_higgs_service():
    """Exercise HiggsTTSService construction, model info and speaker checks.

    Besides the basic construction checks, verifies that ``get_model_info``
    reports the backend name and a ``dependencies_available`` entry, and
    that a Higgs speaker without ``reference_text`` is rejected.
    """
    print("\nTesting HiggsTTSService...")

    # Construction with automatic device selection; nothing is loaded yet.
    higgs = HiggsTTSService(device="auto")
    assert higgs.backend_name == "higgs"
    assert higgs.device in ("cuda", "mps", "cpu")
    assert not higgs.is_loaded()
    print(f"✓ HiggsTTSService created with device: {higgs.device}")

    # Model info must identify the backend and say whether its
    # dependencies are available.
    model_info = higgs.get_model_info()
    assert model_info["backend"] == "higgs"
    assert "dependencies_available" in model_info
    print(f"✓ Higgs model info: dependencies_available={model_info['dependencies_available']}")

    # Complete Higgs speaker (sample + reference text). The sample file is
    # absent, so the result is only printed.
    complete_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="speaker_samples/test.wav",
        reference_text="Hello, this is a test reference.",
        tts_backend="higgs",
    )
    verdict = higgs.validate_speaker_config(complete_speaker)
    print(f"✓ Higgs speaker validation (expected to fail due to missing file): {verdict}")

    # Same backend but no reference_text: must be rejected.
    speaker_without_text = SpeakerConfig(
        id="test-invalid",
        name="Invalid Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
    )
    assert not higgs.validate_speaker_config(speaker_without_text)
    print("✓ Higgs service correctly rejects speaker without reference_text")


def test_factory_pattern():
    """Exercise TTSServiceFactory: discovery, creation, caching and errors.

    Verifies that both built-in backends are registered, that
    ``create_service`` returns the matching concrete class, that repeated
    calls for the same backend return the same instance, that an unknown
    backend raises ``TTSError`` carrying the requested name, and that the
    info/stats helpers report the registered backends.
    """
    print("\nTesting TTSServiceFactory...")

    # Test available backends
    backends = TTSServiceFactory.get_available_backends()
    assert "chatterbox" in backends
    assert "higgs" in backends
    print(f"✓ Available backends: {backends}")

    # Test service creation
    chatterbox_service = TTSServiceFactory.create_service("chatterbox")
    assert isinstance(chatterbox_service, ChatterboxTTSService)
    assert chatterbox_service.backend_name == "chatterbox"
    print("✓ Factory creates ChatterboxTTSService correctly")

    higgs_service = TTSServiceFactory.create_service("higgs")
    assert isinstance(higgs_service, HiggsTTSService)
    assert higgs_service.backend_name == "higgs"
    print("✓ Factory creates HiggsTTSService correctly")

    # Test singleton behavior: a second request returns the same instance
    chatterbox_service2 = TTSServiceFactory.create_service("chatterbox")
    assert chatterbox_service is chatterbox_service2
    print("✓ Factory singleton behavior working")

    # Test unknown backend. Only the call that is expected to raise lives in
    # the ``try``; a missing exception is reported from ``else`` with an
    # explicit raise, because ``assert False`` is stripped under ``python -O``
    # and would let this check pass silently.
    try:
        TTSServiceFactory.create_service("unknown_backend")
    except TTSError as e:
        assert e.backend == "unknown_backend"
    else:
        raise AssertionError("Should have raised TTSError")
    print("✓ Factory correctly handles unknown backend")

    # Test backend info
    info = TTSServiceFactory.get_backend_info()
    assert "chatterbox" in info
    assert "higgs" in info
    print("✓ Backend info retrieval working")

    # Test service stats
    stats = TTSServiceFactory.get_service_stats()
    assert stats["total_backends"] >= 2
    assert "chatterbox" in stats["backends"]
    print(f"✓ Service stats: {stats['total_backends']} backends, {stats['loaded_instances']} instances")


def test_utility_functions():
    """Check the module-level ``list_available_backends`` helper."""
    print("\nTesting utility functions...")

    names = list_available_backends()
    # The helper must return an actual list that includes "chatterbox".
    assert isinstance(names, list)
    assert "chatterbox" in names
    print(f"✓ list_available_backends: {names}")


async def test_async_operations():
    """Check that ``get_tts_service`` resolves to a ChatterboxTTSService.

    Model loading is deliberately not exercised here.
    """
    print("\nTesting async operations...")

    chatterbox = await get_tts_service("chatterbox")
    assert isinstance(chatterbox, ChatterboxTTSService)
    print("✓ get_tts_service utility working")

    # The load/unload lifecycle is skipped to avoid pulling in heavy models.
    print("✓ Async service creation working (model loading skipped for test)")


def test_parameter_handling():
    """Check that per-backend defaults from ``config`` flow into TTSParameters.

    For each backend, builds a ``TTSParameters`` whose ``backend_params``
    come from ``config.TTS_BACKEND_DEFAULTS[backend]`` and checks a few of
    the expected default values.
    """
    print("\nTesting parameter handling...")

    # (backend, temperature, expected backend_params subset, success message)
    expectations = (
        (
            "chatterbox",
            0.7,
            {"exaggeration": 0.5, "cfg_weight": 0.5},
            "✓ Chatterbox parameter defaults loaded",
        ),
        (
            "higgs",
            0.9,
            {"max_new_tokens": 1024, "top_p": 0.95},
            "✓ Higgs parameter defaults loaded",
        ),
    )
    for backend, temperature, expected, message in expectations:
        params = TTSParameters(
            temperature=temperature,
            backend_params=config.TTS_BACKEND_DEFAULTS[backend],
        )
        for key, value in expected.items():
            assert params.backend_params[key] == value
        print(message)


def test_request_response_flow():
    """Build a complete TTSRequest from its parts and check what it exposes.

    No audio is generated; only request construction is verified.
    """
    print("\nTesting request/response flow...")

    sample_text = "Hello, this is a test generation."

    speaker_cfg = SpeakerConfig(
        id="test-speaker",
        name="Test Speaker",
        sample_path="speaker_samples/test.wav",
        tts_backend="chatterbox",
    )
    tuning = TTSParameters(
        temperature=0.8,
        backend_params=config.TTS_BACKEND_DEFAULTS["chatterbox"],
    )
    # Point the output at the system temp directory.
    out_cfg = OutputConfig(
        filename_base="test_generation",
        output_dir=Path(tempfile.gettempdir()),
        format="wav",
    )

    req = TTSRequest(
        text=sample_text,
        speaker_config=speaker_cfg,
        parameters=tuning,
        output_config=out_cfg,
    )

    # The request must expose what it was built from.
    assert req.text == sample_text
    assert req.speaker_config.tts_backend == "chatterbox"
    assert req.parameters.backend_params["exaggeration"] == 0.5
    print("✓ TTS request creation working correctly")


async def test_error_handling():
    """Check that loading the Higgs model either succeeds or fails with TTSError.

    A ``TTSError`` whose ``error_code`` is ``"MISSING_DEPENDENCIES"`` is the
    anticipated outcome when Higgs is not installed; any other ``TTSError``
    is reported but still treated as handled. Exceptions of other types
    propagate and fail the run.

    If loading succeeds, the model is unloaded again. The factory hands out
    a shared instance per backend, so otherwise the loaded model would stay
    resident for the rest of the process.
    """
    print("\nTesting error handling...")

    service = TTSServiceFactory.create_service("higgs")

    # Test handling of missing dependencies (if Higgs not installed).
    # Only the call that may raise sits inside the ``try``.
    try:
        await service.load_model()
    except TTSError as e:
        if e.error_code == "MISSING_DEPENDENCIES":
            print("✓ Higgs service correctly handles missing dependencies")
        else:
            print(f"✓ Higgs service error handling: {e}")
    else:
        print("✓ Higgs model loading (dependencies available)")
        # Release the model we just loaded so it does not leak past this test.
        await service.unload_model()


def test_service_registration():
    """Register a minimal custom backend and create it through the factory.

    The mock implements the async lifecycle methods as no-ops and accepts
    every speaker config; ``generate_speech`` is not invoked by this test.
    """
    print("\nTesting service registration...")

    # Imported locally because only this test needs the abstract base class.
    # TTSRequest/TTSResponse come from the module-level import.
    from backend.app.services.base_tts_service import BaseTTSService

    class CustomTTSService(BaseTTSService):
        """Do-nothing backend used only to exercise factory registration."""

        async def load_model(self):
            pass

        async def unload_model(self):
            pass

        async def generate_speech(self, request: TTSRequest) -> TTSResponse:
            # Use the platform temp directory rather than a hard-coded
            # POSIX "/tmp" path.
            return TTSResponse(
                output_path=Path(tempfile.gettempdir()) / "custom.wav",
                backend_used="custom",
            )

        def validate_speaker_config(self, config):
            return True

    # Register custom service
    TTSServiceFactory.register_service("custom", CustomTTSService)

    # Test creation
    custom_service = TTSServiceFactory.create_service("custom")
    assert isinstance(custom_service, CustomTTSService)
    assert custom_service.backend_name == "custom"
    print("✓ Custom service registration working")


async def main():
    """Run every Phase 2 check in order and report the overall outcome.

    Returns:
        0 when all checks pass; 1 if any check raises, after printing the
        error and its traceback.
    """
    print("=== Phase 2 Implementation Tests ===\n")

    try:
        # Synchronous checks finish when called; the async ones return a
        # coroutine, which is awaited before moving on to the next check.
        for check in (
            test_chatterbox_service,
            test_higgs_service,
            test_factory_pattern,
            test_utility_functions,
            test_async_operations,
            test_parameter_handling,
            test_request_response_flow,
            test_error_handling,
            test_service_registration,
        ):
            pending = check()
            if asyncio.iscoroutine(pending):
                await pending

        summary = (
            "\n=== All Phase 2 tests passed! ✓ ===",
            "\nPhase 2 components ready:",
            "- ChatterboxTTSService (refactored with abstract base)",
            "- HiggsTTSService (with voice cloning support)",
            "- TTSServiceFactory (singleton pattern with lifecycle management)",
            "- Error handling for missing dependencies",
            "- Parameter mapping for different backends",
            "- Service registration for extensibility",
            "\nReady to proceed to Phase 3: Enhanced Data Models and Validation",
        )
        for line in summary:
            print(line)

        return 0

    except Exception as e:
        # Top-level boundary: report any failure and turn it into exit code 1.
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    # sys.exit instead of the ``exit`` builtin: ``exit`` is added by the
    # ``site`` module for interactive use and is missing under ``python -S``
    # and in some embedded/frozen interpreters.
    sys.exit(asyncio.run(main()))