chatterbox-ui/backend/app/services/higgs_tts_service.py

240 lines
9.6 KiB
Python

"""
Higgs TTS Service Implementation
Implements voice cloning using Higgs Audio v2 system
"""
import base64
import torch
import torchaudio
import numpy as np
from pathlib import Path
from typing import Optional
from .base_tts_service import BaseTTSService, TTSError, BackendSpecificError
from ..models.tts_models import TTSRequest, TTSResponse, SpeakerConfig
# Import configuration
try:
from app.config import TTS_TEMP_OUTPUT_DIR, HIGGS_MODEL_PATH, HIGGS_AUDIO_TOKENIZER_PATH, SPEAKER_DATA_BASE_DIR
except ModuleNotFoundError:
# When imported from scripts at project root
from backend.app.config import TTS_TEMP_OUTPUT_DIR, HIGGS_MODEL_PATH, HIGGS_AUDIO_TOKENIZER_PATH, SPEAKER_DATA_BASE_DIR
# Higgs imports (will be imported dynamically to handle missing dependencies)
try:
    from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
    from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
except ImportError as import_error:
    # The optional Higgs dependency is missing: report it, flag the backend as
    # unavailable, and define empty stand-ins so that module-level references
    # (type annotations, isinstance checks) still resolve at import time.
    HIGGS_AVAILABLE = False
    print(f"Warning: Higgs TTS dependencies not available: {import_error}")
    print("To use Higgs TTS, install the boson_multimodal package")

    class HiggsAudioServeEngine:
        pass

    class HiggsAudioResponse:
        pass

    class ChatMLSample:
        pass

    class Message:
        pass

    class AudioContent:
        pass
else:
    # Both imports succeeded, so the real Higgs classes are in scope.
    HIGGS_AVAILABLE = True
class HiggsTTSService(BaseTTSService):
    """Higgs TTS implementation with voice cloning.

    Wraps a ``HiggsAudioServeEngine`` and exposes it through the
    ``BaseTTSService`` interface. Voice cloning is driven by a reference
    audio sample plus its transcript, both taken from the request's
    ``SpeakerConfig``.
    """

    def __init__(self, device: str = "auto",
                 model_path: Optional[str] = None,
                 audio_tokenizer_path: Optional[str] = None):
        """Configure the service without loading the model.

        Args:
            device: Device selector forwarded to ``BaseTTSService``.
            model_path: Model name or path; falls back to ``HIGGS_MODEL_PATH``.
            audio_tokenizer_path: Audio tokenizer name or path; falls back to
                ``HIGGS_AUDIO_TOKENIZER_PATH``.
        """
        super().__init__(device)
        self.backend_name = "higgs"
        self.model_path = model_path or HIGGS_MODEL_PATH
        self.audio_tokenizer_path = audio_tokenizer_path or HIGGS_AUDIO_TOKENIZER_PATH
        # The engine is created lazily by load_model().
        self.engine = None

        if not HIGGS_AVAILABLE:
            print("Warning: Higgs TTS backend created but dependencies not available")

    async def load_model(self) -> None:
        """Load the Higgs TTS engine if it is not loaded yet.

        Raises:
            TTSError: If the Higgs dependencies are missing (error code
                ``MISSING_DEPENDENCIES``) or engine construction fails.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError(
                "Higgs TTS dependencies not available. Install boson_multimodal package.",
                "higgs",
                "MISSING_DEPENDENCIES"
            )

        if self.engine is not None:
            return

        print(f"Loading Higgs TTS model to device: {self.device}...")
        try:
            self.engine = HiggsAudioServeEngine(
                model_name_or_path=self.model_path,
                audio_tokenizer_name_or_path=self.audio_tokenizer_path,
                device=self.device,
            )
            self.model = self.engine  # Set model for is_loaded() check
            print("Higgs TTS model loaded successfully.")
        except Exception as e:
            # Chain the original exception so the root cause stays in the traceback.
            raise TTSError(f"Error loading Higgs TTS model: {e}", "higgs") from e

    async def unload_model(self) -> None:
        """Drop the engine reference and release memory; no-op when not loaded."""
        if self.engine is None:
            return

        print("Unloading Higgs TTS model...")
        # Clear both references so the engine object can be collected.
        self.engine = None
        self.model = None
        self._cleanup_memory()
        print("Higgs TTS model unloaded.")

    def validate_speaker_config(self, config: SpeakerConfig) -> bool:
        """Return True when ``config`` is usable by the Higgs backend.

        A config is valid when it targets the ``"higgs"`` backend, carries a
        non-empty ``reference_text``, and its ``sample_path`` resolves to an
        existing file.
        """
        if config.tts_backend != "higgs":
            return False

        if not config.reference_text:
            return False

        # A missing sample path is an invalid config, not a crash.
        if not config.sample_path:
            return False

        # Reuse the shared resolver so validation and generation agree on the path.
        return Path(self._resolve_sample_path(config)).is_file()

    def _resolve_sample_path(self, config: SpeakerConfig) -> str:
        """Return the sample path as a string.

        Relative paths are joined onto ``SPEAKER_DATA_BASE_DIR``.
        """
        sample_path = Path(config.sample_path)
        if not sample_path.is_absolute():
            # Path() makes the join work whether the base dir is a str or a Path.
            sample_path = Path(SPEAKER_DATA_BASE_DIR) / sample_path
        return str(sample_path)

    def _encode_audio_to_base64(self, audio_path: str) -> str:
        """Read ``audio_path`` and return its bytes as a base64 string.

        Raises:
            BackendSpecificError: If the file cannot be read or encoded.
        """
        try:
            with open(audio_path, "rb") as audio_file:
                return base64.b64encode(audio_file.read()).decode("utf-8")
        except Exception as e:
            raise BackendSpecificError(
                f"Failed to encode audio file {audio_path}: {e}", "higgs"
            ) from e

    def _create_chatml_sample(self, request: TTSRequest) -> ChatMLSample:
        """Build the ChatML conversation used for voice cloning.

        The conversation is: user turn with the reference transcript,
        assistant turn with the reference audio, user turn with the text
        to synthesize.

        Raises:
            TTSError: If the Higgs dependencies are missing.
            BackendSpecificError: If the sample cannot be built.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError("Higgs TTS dependencies not available", "higgs", "MISSING_DEPENDENCIES")

        try:
            # Get absolute path to audio sample
            audio_path = self._resolve_sample_path(request.speaker_config)

            # Encode reference audio
            reference_audio_b64 = self._encode_audio_to_base64(audio_path)

            # Create conversation pattern for voice cloning
            messages = [
                Message(
                    role="user",
                    content=request.speaker_config.reference_text,
                ),
                Message(
                    role="assistant",
                    content=AudioContent(
                        raw_audio=reference_audio_b64,
                        audio_url="placeholder"
                    ),
                ),
                Message(
                    role="user",
                    content=request.text,
                ),
            ]

            return ChatMLSample(messages=messages)
        except Exception as e:
            raise BackendSpecificError(f"Error creating ChatML sample: {e}", "higgs") from e

    async def generate_speech(self, request: TTSRequest) -> TTSResponse:
        """Generate speech for ``request`` and save it to disk.

        Loads the model on first use, validates the speaker config, runs the
        engine, and writes the audio to
        ``<output_dir>/<filename_base>.<format>``.

        Returns:
            A ``TTSResponse`` with the output path, generated text, duration
            in seconds, sampling rate, and backend name.

        Raises:
            TTSError: If dependencies are missing, the speaker config is
                invalid, or generation or saving fails.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError("Higgs TTS dependencies not available", "higgs", "MISSING_DEPENDENCIES")

        if self.engine is None:
            await self.load_model()

        # Validate speaker configuration
        if not self.validate_speaker_config(request.speaker_config):
            raise TTSError(
                f"Invalid speaker config for Higgs: {request.speaker_config.name}. "
                f"Ensure reference_text is provided and audio sample exists.",
                "higgs"
            )

        # Extract Higgs-specific parameters; tolerate a missing params mapping.
        backend_params = request.parameters.backend_params or {}
        max_new_tokens = backend_params.get("max_new_tokens", 1024)
        temperature = request.parameters.temperature
        top_p = backend_params.get("top_p", 0.95)
        top_k = backend_params.get("top_k", 50)
        stop_strings = backend_params.get("stop_strings", ["<|end_of_text|>", "<|eot_id|>"])

        # Set up output path; Path() accepts both str and Path directories.
        output_dir = Path(request.output_config.output_dir or TTS_TEMP_OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{request.output_config.filename_base}.{request.output_config.format}"

        print(f'Generating Higgs TTS audio for: "{request.text[:50]}..." with speaker: {request.speaker_config.name}')
        print(f'Using reference text: "{request.speaker_config.reference_text[:30]}..."')

        # Create ChatML sample and generate speech
        try:
            chat_sample = self._create_chatml_sample(request)

            # NOTE(review): this call is synchronous inside a coroutine; if it
            # is long-running it will block the event loop. Confirm whether
            # it should be offloaded to a worker thread.
            response: HiggsAudioResponse = self.engine.generate(
                chat_ml_sample=chat_sample,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=stop_strings,
            )

            if response.audio is None:
                raise BackendSpecificError("No audio generated by Higgs TTS", "higgs")

            # torchaudio.save is given a (channels, samples) tensor, so a 1D
            # array gets a leading channel dimension.
            audio_array = response.audio
            audio_tensor = torch.from_numpy(audio_array)
            if audio_array.ndim == 1:
                audio_tensor = audio_tensor.unsqueeze(0)

            torchaudio.save(str(output_path), audio_tensor, response.sampling_rate)
            print(f"Higgs TTS audio saved to: {output_path}")

            # Duration comes from the sample axis (last dimension). len() on a
            # 2D array would return the channel count instead.
            audio_duration = audio_tensor.shape[-1] / response.sampling_rate

            return TTSResponse(
                output_path=output_path,
                generated_text=response.generated_text,
                audio_duration=audio_duration,
                sampling_rate=response.sampling_rate,
                backend_used=self.backend_name
            )
        except TTSError:
            # Already a domain error; propagate unchanged.
            raise
        except Exception as e:
            raise TTSError(f"Error during Higgs TTS generation: {e}", "higgs") from e
        finally:
            # Release memory whether generation succeeded or failed.
            self._cleanup_memory()

    def get_model_info(self) -> dict:
        """Return a summary of the backend configuration and load state."""
        return {
            "backend": self.backend_name,
            "model_path": self.model_path,
            "audio_tokenizer_path": self.audio_tokenizer_path,
            "device": self.device,
            "loaded": self.is_loaded(),
            "dependencies_available": HIGGS_AVAILABLE
        }