197 lines
6.5 KiB
Python
197 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for Phase 1 implementation - Abstract base class and data models
|
|
"""
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add the project root (two levels above this script) to the import path so the
# ``backend.app`` imports below resolve to this checkout. The root is prepended
# rather than appended so the local source tree takes priority over any
# installed copy of the package. resolve() turns a relative or symlinked script
# path into the real directory. The membership guard avoids duplicate entries if
# this module is imported more than once.
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
|
|
|
|
from backend.app.models.tts_models import (
|
|
TTSParameters, SpeakerConfig, OutputConfig, TTSRequest, TTSResponse
|
|
)
|
|
from backend.app.services.base_tts_service import BaseTTSService, TTSError
|
|
from backend.app import config
|
|
|
|
def test_data_models():
    """Exercise construction and basic validation of the TTS data models.

    Builds TTSParameters, SpeakerConfig (for the "chatterbox" and "higgs"
    backends), OutputConfig, TTSRequest and TTSResponse instances, and asserts
    on the fields it passes in.

    Raises:
        AssertionError: if a field check fails, or if a "higgs" SpeakerConfig
            without ``reference_text`` is accepted.
    """
    print("Testing TTS data models...")

    # Test TTSParameters
    params = TTSParameters(
        temperature=0.8,
        backend_params={"max_new_tokens": 512, "top_p": 0.9},
    )
    assert params.temperature == 0.8
    assert params.backend_params["max_new_tokens"] == 512
    print("✓ TTSParameters working correctly")

    # Test SpeakerConfig for chatterbox backend
    speaker_config_chatterbox = SpeakerConfig(
        id="test-speaker-1",
        name="Test Speaker",
        sample_path="/tmp/test_sample.wav",
        tts_backend="chatterbox",
    )
    print("✓ SpeakerConfig for chatterbox backend working")

    # A "higgs" speaker without reference_text must be rejected. Both the
    # constructor and validate() sit inside the try, so a ValueError from
    # either one counts as the expected rejection. The failure path raises
    # explicitly in ``else`` instead of using ``assert False``, because
    # ``python -O`` strips asserts and this negative check would then pass
    # silently.
    try:
        speaker_config_higgs_invalid = SpeakerConfig(
            id="test-speaker-2",
            name="Invalid Higgs Speaker",
            sample_path="/tmp/test_sample.wav",
            tts_backend="higgs",
        )
        speaker_config_higgs_invalid.validate()
    except ValueError:
        print("✓ SpeakerConfig validation correctly catches missing reference_text for higgs")
    else:
        raise AssertionError("Should have raised ValueError for missing reference_text")

    # Test valid SpeakerConfig for higgs backend
    speaker_config_higgs_valid = SpeakerConfig(
        id="test-speaker-3",
        name="Valid Higgs Speaker",
        sample_path="/tmp/test_sample.wav",
        reference_text="Hello, this is a test.",
        tts_backend="higgs",
    )
    speaker_config_higgs_valid.validate()  # Should not raise
    print("✓ SpeakerConfig for higgs backend with reference_text working")

    # Test OutputConfig
    output_config = OutputConfig(
        filename_base="test_output",
        output_dir=Path("/tmp"),
        format="wav",
    )
    assert output_config.filename_base == "test_output"
    print("✓ OutputConfig working correctly")

    # Test TTSRequest
    request = TTSRequest(
        text="Hello world, this is a test.",
        speaker_config=speaker_config_chatterbox,
        parameters=params,
        output_config=output_config,
    )
    assert request.text == "Hello world, this is a test."
    assert request.speaker_config.name == "Test Speaker"
    print("✓ TTSRequest working correctly")

    # Test TTSResponse
    response = TTSResponse(
        output_path=Path("/tmp/output.wav"),
        generated_text="Hello world, this is a test.",
        audio_duration=3.5,
        sampling_rate=22050,
        backend_used="chatterbox",
    )
    assert response.audio_duration == 3.5
    assert response.backend_used == "chatterbox"
    print("✓ TTSResponse working correctly")
|
|
|
|
def test_base_service():
    """Check the BaseTTSService contract using a minimal concrete subclass.

    Defines a throwaway subclass that implements load/unload/generate/validate,
    then checks device auto-resolution, backend-name derivation and the
    initial loaded state of a fresh instance.
    """
    print("\nTesting abstract base service...")

    # Minimal concrete implementation. Keep the class name unchanged: the
    # test expects backend_name == "mock", which presumably is derived from
    # the class name by BaseTTSService (confirm there).
    class MockTTSService(BaseTTSService):
        async def load_model(self):
            """Pretend to load a model by storing a marker string."""
            self.model = "mock_model_loaded"

        async def unload_model(self):
            """Drop the marker model."""
            self.model = None

        async def generate_speech(self, request):
            """Return a canned response tagged with this service's backend name."""
            return TTSResponse(
                output_path=Path("/tmp/mock_output.wav"),
                backend_used=self.backend_name,
            )

        def validate_speaker_config(self, config):
            """Accept every speaker configuration."""
            return True

    service = MockTTSService(device="auto")

    # "auto" must resolve to one of the concrete device names.
    assert service.device in ("cuda", "mps", "cpu")
    print(f"✓ Device auto-resolution: {service.device}")

    # Backend name is derived by the base class.
    assert service.backend_name == "mock"
    print("✓ Backend name extraction working")

    # load_model() has not been awaited yet, so nothing should report loaded.
    assert not service.is_loaded()
    print("✓ Initial model state check")
|
|
|
|
def test_configuration():
    """Verify the config module exposes the TTS settings this phase relies on.

    Checks that the Higgs model/tokenizer paths, the default backend and the
    per-backend defaults table are defined, prints two of them, and spot-checks
    that the defaults table has entries for "chatterbox" and "higgs".
    """
    print("\nTesting configuration...")

    required_settings = (
        'HIGGS_MODEL_PATH',
        'HIGGS_AUDIO_TOKENIZER_PATH',
        'DEFAULT_TTS_BACKEND',
        'TTS_BACKEND_DEFAULTS',
    )
    for setting in required_settings:
        assert hasattr(config, setting)

    print(f"✓ Default TTS backend: {config.DEFAULT_TTS_BACKEND}")
    print(f"✓ Higgs model path: {config.HIGGS_MODEL_PATH}")

    # Each backend needs an entry, and each entry a representative key.
    backend_defaults = config.TTS_BACKEND_DEFAULTS
    assert "chatterbox" in backend_defaults
    assert "higgs" in backend_defaults
    assert "temperature" in backend_defaults["chatterbox"]
    assert "max_new_tokens" in backend_defaults["higgs"]

    print("✓ TTS backend defaults configured correctly")
|
|
|
|
def test_error_handling():
    """Check the attributes and hierarchy of the TTS exception classes.

    Raises and catches a TTSError to confirm it exposes the backend and error
    code it was constructed with. Then raises a BackendSpecificError and
    catches it through the TTSError base class to confirm the subclass
    relationship.
    """
    print("\nTesting error handling...")

    # Positional args are presumably (message, backend, error_code); the
    # latter two must be readable back as attributes.
    try:
        raise TTSError("Test error", "test_backend", "ERROR_001")
    except TTSError as caught:
        assert caught.backend == "test_backend"
        assert caught.error_code == "ERROR_001"
    print("✓ TTSError working correctly")

    # Imported here, after the first check, mirroring the original ordering.
    from backend.app.services.base_tts_service import BackendSpecificError

    # Catching via the base class only works if BackendSpecificError
    # subclasses TTSError; otherwise the exception escapes to the caller.
    try:
        raise BackendSpecificError("Backend specific error", "higgs")
    except TTSError as caught:
        assert caught.backend == "higgs"
    print("✓ BackendSpecificError inheritance working correctly")
|
|
|
|
def main():
    """Run every Phase 1 check in order and report the outcome.

    Returns:
        0 if all checks pass, or 1 if any check raises. On failure the
        exception type and message are printed, followed by the traceback.
    """
    print("=== Phase 1 Implementation Tests ===\n")

    # Only the checks themselves sit inside the try. A failure in the summary
    # prints below is not misreported as a failed test.
    try:
        for check in (
            test_data_models,
            test_base_service,
            test_configuration,
            test_error_handling,
        ):
            check()
    except Exception as e:
        # The checks use bare ``assert`` statements, and their AssertionError
        # has an empty message. Include the exception type so the summary
        # line is never just "Test failed: ".
        print(f"\n❌ Test failed: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return 1

    print("\n=== All Phase 1 tests passed! ✓ ===")
    print("\nPhase 1 components ready:")
    print("- TTS data models (TTSRequest, TTSResponse, etc.)")
    print("- Abstract BaseTTSService class")
    print("- Configuration system with Higgs support")
    print("- Error handling framework")
    print("\nReady to proceed to Phase 2: Service Implementation")

    return 0
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit, not the ``exit`` builtin. ``exit`` is injected by the
    # ``site`` module for interactive sessions and is absent under ``python -S``.
    sys.exit(main())