chatterbox-ui/backend/test_phase1.py

197 lines
6.5 KiB
Python

#!/usr/bin/env python3
"""
Test script for Phase 1 implementation - Abstract base class and data models
"""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from backend.app.models.tts_models import (
TTSParameters, SpeakerConfig, OutputConfig, TTSRequest, TTSResponse
)
from backend.app.services.base_tts_service import BaseTTSService, TTSError
from backend.app import config
def test_data_models():
    """Smoke-test the TTS data models.

    Builds one instance of each model (TTSParameters, SpeakerConfig,
    OutputConfig, TTSRequest, TTSResponse) from fixed values and checks that
    the fields read back as given. Also checks that a "higgs" SpeakerConfig
    without ``reference_text`` is rejected with ValueError, while one with
    ``reference_text`` passes ``validate()``.

    Raises:
        AssertionError: if any check fails, including when the invalid higgs
            speaker is NOT rejected.
    """
    print("Testing TTS data models...")

    # TTSParameters: a common field plus a free-form backend-specific dict.
    params = TTSParameters(
        temperature=0.8,
        backend_params={"max_new_tokens": 512, "top_p": 0.9},
    )
    assert params.temperature == 0.8
    assert params.backend_params["max_new_tokens"] == 512
    print("✓ TTSParameters working correctly")

    # SpeakerConfig for the chatterbox backend (no reference_text supplied).
    speaker_config_chatterbox = SpeakerConfig(
        id="test-speaker-1",
        name="Test Speaker",
        sample_path="/tmp/test_sample.wav",
        tts_backend="chatterbox",
    )
    print("✓ SpeakerConfig for chatterbox backend working")

    # Negative case: a higgs speaker without reference_text must be rejected
    # with ValueError, whether by the constructor or by validate().
    # The "did not raise" failure lives in `else` instead of an
    # `assert False` inside `try`, so it still fires under `python -O`
    # (which strips assert statements).
    try:
        speaker_config_higgs_invalid = SpeakerConfig(
            id="test-speaker-2",
            name="Invalid Higgs Speaker",
            sample_path="/tmp/test_sample.wav",
            tts_backend="higgs",
        )
        speaker_config_higgs_invalid.validate()
    except ValueError:
        print("✓ SpeakerConfig validation correctly catches missing reference_text for higgs")
    else:
        raise AssertionError("Should have raised ValueError for missing reference_text")

    # Positive case: the same higgs speaker with reference_text validates.
    speaker_config_higgs_valid = SpeakerConfig(
        id="test-speaker-3",
        name="Valid Higgs Speaker",
        sample_path="/tmp/test_sample.wav",
        reference_text="Hello, this is a test.",
        tts_backend="higgs",
    )
    speaker_config_higgs_valid.validate()  # Should not raise
    print("✓ SpeakerConfig for higgs backend with reference_text working")

    # OutputConfig
    output_config = OutputConfig(
        filename_base="test_output",
        output_dir=Path("/tmp"),
        format="wav",
    )
    assert output_config.filename_base == "test_output"
    print("✓ OutputConfig working correctly")

    # TTSRequest composes the pieces built above.
    request = TTSRequest(
        text="Hello world, this is a test.",
        speaker_config=speaker_config_chatterbox,
        parameters=params,
        output_config=output_config,
    )
    assert request.text == "Hello world, this is a test."
    assert request.speaker_config.name == "Test Speaker"
    print("✓ TTSRequest working correctly")

    # TTSResponse
    response = TTSResponse(
        output_path=Path("/tmp/output.wav"),
        generated_text="Hello world, this is a test.",
        audio_duration=3.5,
        sampling_rate=22050,
        backend_used="chatterbox",
    )
    assert response.audio_duration == 3.5
    assert response.backend_used == "chatterbox"
    print("✓ TTSResponse working correctly")
def test_base_service():
    """Exercise BaseTTSService through a minimal concrete subclass.

    Checks three things on a freshly constructed service: that
    ``device="auto"`` resolves to one of "cuda", "mps" or "cpu", that
    ``backend_name`` is "mock", and that ``is_loaded()`` is falsy before any
    model has been loaded.
    """
    print("\nTesting abstract base service...")

    # Smallest subclass that overrides the methods this test treats as the
    # abstract interface. Keep the class name: the assertion below expects
    # backend_name == "mock", which presumably the base class derives from
    # the class name (not visible here — confirm in base_tts_service).
    class MockTTSService(BaseTTSService):
        async def load_model(self):
            self.model = "mock_model_loaded"

        async def unload_model(self):
            self.model = None

        async def generate_speech(self, request):
            return TTSResponse(
                output_path=Path("/tmp/mock_output.wav"),
                backend_used=self.backend_name,
            )

        def validate_speaker_config(self, config):
            return True

    service = MockTTSService(device="auto")

    # "auto" must collapse to a concrete device string.
    assert service.device in ("cuda", "mps", "cpu")
    print(f"✓ Device auto-resolution: {service.device}")

    assert service.backend_name == "mock"
    print("✓ Backend name extraction working")

    # Nothing has called load_model() yet.
    assert not service.is_loaded()
    print("✓ Initial model state check")
def test_configuration():
    """Check that ``config`` exposes the multi-backend TTS settings.

    The module must define the Higgs model/tokenizer paths, a default
    backend, and a per-backend defaults table. That table must contain both
    "chatterbox" (with a "temperature" key) and "higgs" (with a
    "max_new_tokens" key).
    """
    print("\nTesting configuration...")

    required_settings = (
        "HIGGS_MODEL_PATH",
        "HIGGS_AUDIO_TOKENIZER_PATH",
        "DEFAULT_TTS_BACKEND",
        "TTS_BACKEND_DEFAULTS",
    )
    for setting_name in required_settings:
        assert hasattr(config, setting_name)

    print(f"✓ Default TTS backend: {config.DEFAULT_TTS_BACKEND}")
    print(f"✓ Higgs model path: {config.HIGGS_MODEL_PATH}")

    # Per-backend defaults: each supported backend has its own entry.
    backend_defaults = config.TTS_BACKEND_DEFAULTS
    for backend_key in ("chatterbox", "higgs"):
        assert backend_key in backend_defaults
    assert "temperature" in backend_defaults["chatterbox"]
    assert "max_new_tokens" in backend_defaults["higgs"]
    print("✓ TTS backend defaults configured correctly")
def test_error_handling():
    """Check the TTS exception classes.

    Verifies that TTSError exposes the backend and error code it was raised
    with, and that BackendSpecificError can be caught as a TTSError.
    """
    print("\nTesting error handling...")

    # The 2nd and 3rd positional arguments are expected to surface as
    # `.backend` and `.error_code`.
    try:
        raise TTSError("Test error", "test_backend", "ERROR_001")
    except TTSError as caught:
        assert caught.backend == "test_backend"
        assert caught.error_code == "ERROR_001"
        print("✓ TTSError working correctly")

    from backend.app.services.base_tts_service import BackendSpecificError

    # Catching via the base class is the actual check: if BackendSpecificError
    # did not derive from TTSError, the raise would escape and fail the test.
    try:
        raise BackendSpecificError("Backend specific error", "higgs")
    except TTSError as caught:
        assert caught.backend == "higgs"
        print("✓ BackendSpecificError inheritance working correctly")
def main():
    """Run every Phase 1 test in order and report the outcome.

    The first failing test stops the run. Its exception is summarised on one
    line and the full traceback is printed to stderr.

    Returns:
        0 if all tests pass, 1 if any test raised an exception.
    """
    print("=== Phase 1 Implementation Tests ===\n")
    try:
        test_data_models()
        test_base_service()
        test_configuration()
        test_error_handling()
    except Exception as e:
        # Most checks are bare `assert`s, whose str() is empty; include the
        # exception type so the summary line says *what* failed.
        reason = f"{type(e).__name__}: {e}" if str(e) else type(e).__name__
        print(f"\n❌ Test failed: {reason}")
        import traceback
        traceback.print_exc()
        return 1
    else:
        # Success banner sits outside the `try` so that only the tests
        # themselves are reported as failures.
        print("\n=== All Phase 1 tests passed! ✓ ===")
        print("\nPhase 1 components ready:")
        print("- TTS data models (TTSRequest, TTSResponse, etc.)")
        print("- Abstract BaseTTSService class")
        print("- Configuration system with Higgs support")
        print("- Error handling framework")
        print("\nReady to proceed to Phase 2: Service Implementation")
    return 0
if __name__ == "__main__":
    # Use sys.exit, not the site-module `exit()` helper, which is meant for
    # the interactive shell and is absent under `python -S` / frozen builds.
    sys.exit(main())