chatterbox-ui/backend/test_phase2.py

296 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Test script for Phase 2 implementation - Service implementations and factory
"""
import sys
import asyncio
import tempfile
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from backend.app.models.tts_models import (
TTSParameters, SpeakerConfig, OutputConfig, TTSRequest, TTSResponse
)
from backend.app.services.chatterbox_tts_service import ChatterboxTTSService
from backend.app.services.higgs_tts_service import HiggsTTSService
from backend.app.services.tts_factory import TTSServiceFactory, get_tts_service, list_available_backends
from backend.app.services.base_tts_service import TTSError
from backend.app import config
def test_chatterbox_service():
    """Exercise ChatterboxTTSService construction and speaker validation.

    Checks the backend name, the auto-selected device, the initial unloaded
    state, that validating a chatterbox speaker with a missing sample file
    does not crash, and that a speaker bound to another backend is rejected.
    """
    print("Testing ChatterboxTTSService...")

    # Build the service with automatic device selection.
    svc = ChatterboxTTSService(device="auto")
    assert svc.backend_name == "chatterbox"
    assert svc.device in ["cuda", "mps", "cpu"]
    assert not svc.is_loaded()
    print(f"✓ ChatterboxTTSService created with device: {svc.device}")

    # A chatterbox speaker whose sample file does not exist: validation is
    # expected to fail, but the call itself must not crash.
    own_speaker = SpeakerConfig(
        id="test-chatterbox",
        name="Test Chatterbox Speaker",
        sample_path="speaker_samples/test.wav",  # relative path
        tts_backend="chatterbox",
    )
    outcome = svc.validate_speaker_config(own_speaker)
    print(f"✓ Speaker validation (expected to fail due to missing file): {outcome}")

    # A speaker configured for the Higgs backend must be rejected.
    foreign_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
    )
    assert not svc.validate_speaker_config(foreign_speaker)
    print("✓ Chatterbox service correctly rejects Higgs speaker")
def test_higgs_service():
    """Exercise HiggsTTSService construction, model info and speaker validation.

    Verifies the backend name, the auto-selected device, the initial unloaded
    state, the keys reported by ``get_model_info()``, and that a speaker
    lacking ``reference_text`` is rejected.
    """
    print("\nTesting HiggsTTSService...")

    # Build the service with automatic device selection.
    svc = HiggsTTSService(device="auto")
    assert svc.backend_name == "higgs"
    assert svc.device in ["cuda", "mps", "cpu"]
    assert not svc.is_loaded()
    print(f"✓ HiggsTTSService created with device: {svc.device}")

    # The model info must identify the backend and report dependency status.
    model_info = svc.get_model_info()
    assert model_info["backend"] == "higgs"
    assert "dependencies_available" in model_info
    print(f"✓ Higgs model info: dependencies_available={model_info['dependencies_available']}")

    # Fully specified Higgs speaker; validation is still expected to fail
    # because the sample file does not exist.
    complete_speaker = SpeakerConfig(
        id="test-higgs",
        name="Test Higgs Speaker",
        sample_path="speaker_samples/test.wav",
        reference_text="Hello, this is a test reference.",
        tts_backend="higgs",
    )
    outcome = svc.validate_speaker_config(complete_speaker)
    print(f"✓ Higgs speaker validation (expected to fail due to missing file): {outcome}")

    # No reference_text is supplied, so the Higgs service must reject it.
    speaker_without_ref = SpeakerConfig(
        id="test-invalid",
        name="Invalid Speaker",
        sample_path="test.wav",
        tts_backend="higgs",
    )
    assert not svc.validate_speaker_config(speaker_without_ref)
    print("✓ Higgs service correctly rejects speaker without reference_text")
def test_factory_pattern():
    """Exercise TTSServiceFactory: discovery, creation, caching and errors.

    Checks that:
    - both built-in backends are advertised;
    - ``create_service`` returns the matching service class;
    - repeated creation returns the same instance;
    - an unknown backend name raises ``TTSError`` carrying that name;
    - backend info and service stats include the built-in backends.
    """
    print("\nTesting TTSServiceFactory...")

    # Available backends
    backends = TTSServiceFactory.get_available_backends()
    assert "chatterbox" in backends
    assert "higgs" in backends
    print(f"✓ Available backends: {backends}")

    # Service creation
    chatterbox_service = TTSServiceFactory.create_service("chatterbox")
    assert isinstance(chatterbox_service, ChatterboxTTSService)
    assert chatterbox_service.backend_name == "chatterbox"
    print("✓ Factory creates ChatterboxTTSService correctly")

    higgs_service = TTSServiceFactory.create_service("higgs")
    assert isinstance(higgs_service, HiggsTTSService)
    assert higgs_service.backend_name == "higgs"
    print("✓ Factory creates HiggsTTSService correctly")

    # Repeated creation must hand back the cached instance.
    chatterbox_service2 = TTSServiceFactory.create_service("chatterbox")
    assert chatterbox_service is chatterbox_service2
    print("✓ Factory singleton behavior working")

    # Unknown backend: keep only the raising call inside ``try``, and signal
    # a missing exception with an explicit raise. ``assert False`` would be
    # stripped under ``python -O``, which would let this check pass silently.
    try:
        TTSServiceFactory.create_service("unknown_backend")
    except TTSError as e:
        assert e.backend == "unknown_backend"
    else:
        raise AssertionError("Should have raised TTSError")
    print("✓ Factory correctly handles unknown backend")

    # Backend info
    info = TTSServiceFactory.get_backend_info()
    assert "chatterbox" in info
    assert "higgs" in info
    print("✓ Backend info retrieval working")

    # Service stats
    stats = TTSServiceFactory.get_service_stats()
    assert stats["total_backends"] >= 2
    assert "chatterbox" in stats["backends"]
    print(f"✓ Service stats: {stats['total_backends']} backends, {stats['loaded_instances']} instances")
def test_utility_functions():
    """Check that ``list_available_backends()`` returns a list including "chatterbox"."""
    print("\nTesting utility functions...")

    names = list_available_backends()
    assert isinstance(names, list)
    assert "chatterbox" in names
    print(f"✓ list_available_backends: {names}")
async def test_async_operations():
    """Check that the async ``get_tts_service()`` helper yields a Chatterbox service.

    Only service retrieval is exercised here; this test does not call
    ``load_model()`` itself.
    """
    print("\nTesting async operations...")

    tts = await get_tts_service("chatterbox")
    assert isinstance(tts, ChatterboxTTSService)
    print("✓ get_tts_service utility working")

    print("✓ Async service creation working (model loading skipped for test)")
def test_parameter_handling():
    """Check that per-backend defaults from ``config`` reach ``TTSParameters``.

    Builds one parameter set per backend from ``config.TTS_BACKEND_DEFAULTS``
    and asserts the expected default values are visible in ``backend_params``.
    """
    print("\nTesting parameter handling...")
    defaults = config.TTS_BACKEND_DEFAULTS

    # Chatterbox defaults
    cb_params = TTSParameters(temperature=0.7, backend_params=defaults["chatterbox"])
    cb_extra = cb_params.backend_params
    assert cb_extra["exaggeration"] == 0.5
    assert cb_extra["cfg_weight"] == 0.5
    print("✓ Chatterbox parameter defaults loaded")

    # Higgs defaults
    hg_params = TTSParameters(temperature=0.9, backend_params=defaults["higgs"])
    hg_extra = hg_params.backend_params
    assert hg_extra["max_new_tokens"] == 1024
    assert hg_extra["top_p"] == 0.95
    print("✓ Higgs parameter defaults loaded")
def test_request_response_flow():
    """Assemble a complete ``TTSRequest`` and check its fields round-trip.

    No speech is generated. The speaker, parameters and output config are
    built inline (in that order) and passed into the request. The request's
    text, speaker backend and a backend parameter are then asserted.
    """
    print("\nTesting request/response flow...")

    request = TTSRequest(
        text="Hello, this is a test generation.",
        speaker_config=SpeakerConfig(
            id="test-speaker",
            name="Test Speaker",
            sample_path="speaker_samples/test.wav",
            tts_backend="chatterbox",
        ),
        parameters=TTSParameters(
            temperature=0.8,
            backend_params=config.TTS_BACKEND_DEFAULTS["chatterbox"],
        ),
        # Output goes to the platform temp dir; nothing is written here.
        output_config=OutputConfig(
            filename_base="test_generation",
            output_dir=Path(tempfile.gettempdir()),
            format="wav",
        ),
    )

    assert request.text == "Hello, this is a test generation."
    assert request.speaker_config.tts_backend == "chatterbox"
    assert request.parameters.backend_params["exaggeration"] == 0.5
    print("✓ TTS request creation working correctly")
async def test_error_handling():
    """Check that loading the Higgs model either succeeds or raises ``TTSError``.

    If Higgs dependencies are missing, ``load_model()`` is expected to raise
    ``TTSError`` with ``error_code == "MISSING_DEPENDENCIES"``. Any other
    ``TTSError`` is reported but not treated as a failure. If loading
    succeeds, the model is unloaded again.
    """
    print("\nTesting error handling...")
    service = TTSServiceFactory.create_service("higgs")

    # Only the call that can raise lives inside ``try``.
    try:
        await service.load_model()
    except TTSError as e:
        if e.error_code == "MISSING_DEPENDENCIES":
            print("✓ Higgs service correctly handles missing dependencies")
        else:
            print(f"✓ Higgs service error handling: {e}")
    else:
        print("✓ Higgs model loading (dependencies available)")
        # The factory hands back a cached instance (see test_factory_pattern),
        # so a loaded model would otherwise stay resident for the rest of the
        # run. Release it now that the load path has been verified.
        await service.unload_model()
def test_service_registration():
    """Register a minimal ``BaseTTSService`` subclass and create it via the factory.

    Checks that the factory returns an instance of the registered class and
    that its ``backend_name`` matches the name it was registered under.
    """
    print("\nTesting service registration...")
    from backend.app.services.base_tts_service import BaseTTSService

    class CustomTTSService(BaseTTSService):
        # No-op backend used only to exercise factory registration.

        async def load_model(self):
            pass

        async def unload_model(self):
            pass

        async def generate_speech(self, request: TTSRequest) -> TTSResponse:
            return TTSResponse(output_path=Path("/tmp/custom.wav"), backend_used="custom")

        def validate_speaker_config(self, config):
            return True

    TTSServiceFactory.register_service("custom", CustomTTSService)

    created = TTSServiceFactory.create_service("custom")
    assert isinstance(created, CustomTTSService)
    assert created.backend_name == "custom"
    print("✓ Custom service registration working")
async def main():
    """Run every Phase 2 check in a fixed order and report the outcome.

    Returns:
        0 when all checks complete without raising; 1 if any check raises.
        In that case the exception and its traceback are printed.
    """
    print("=== Phase 2 Implementation Tests ===\n")
    try:
        # Synchronous and asynchronous checks, interleaved in a fixed order.
        test_chatterbox_service()
        test_higgs_service()
        test_factory_pattern()
        test_utility_functions()
        await test_async_operations()
        test_parameter_handling()
        test_request_response_flow()
        await test_error_handling()
        test_service_registration()

        print("\n=== All Phase 2 tests passed! ✓ ===")
        print("\nPhase 2 components ready:")
        for component in (
            "- ChatterboxTTSService (refactored with abstract base)",
            "- HiggsTTSService (with voice cloning support)",
            "- TTSServiceFactory (singleton pattern with lifecycle management)",
            "- Error handling for missing dependencies",
            "- Parameter mapping for different backends",
            "- Service registration for extensibility",
        ):
            print(component)
        print("\nReady to proceed to Phase 3: Enhanced Data Models and Validation")
        return 0
    except Exception as exc:
        # Top-level boundary: report any failure and signal it via the
        # return code instead of propagating.
        print(f"\n❌ Test failed: {exc}")
        import traceback
        traceback.print_exc()
        return 1
if __name__ == "__main__":
    # Use sys.exit rather than the site-injected ``exit()`` builtin. That
    # builtin is meant for the interactive shell and is absent under
    # ``python -S`` or in some embedded/frozen interpreters.
    sys.exit(asyncio.run(main()))