240 lines
9.6 KiB
Python
240 lines
9.6 KiB
Python
"""
|
|
Higgs TTS Service Implementation
|
|
Implements voice cloning using Higgs Audio v2 system
|
|
"""
|
|
import base64
|
|
import torch
|
|
import torchaudio
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from .base_tts_service import BaseTTSService, TTSError, BackendSpecificError
|
|
from ..models.tts_models import TTSRequest, TTSResponse, SpeakerConfig
|
|
|
|
# Import configuration
|
|
try:
|
|
from app.config import TTS_TEMP_OUTPUT_DIR, HIGGS_MODEL_PATH, HIGGS_AUDIO_TOKENIZER_PATH, SPEAKER_DATA_BASE_DIR
|
|
except ModuleNotFoundError:
|
|
# When imported from scripts at project root
|
|
from backend.app.config import TTS_TEMP_OUTPUT_DIR, HIGGS_MODEL_PATH, HIGGS_AUDIO_TOKENIZER_PATH, SPEAKER_DATA_BASE_DIR
|
|
|
|
# Higgs imports are optional: an ImportError is tolerated so this module still loads without boson_multimodal
|
|
try:
    from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
    from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
except ImportError as import_error:
    # The backend stays importable without its dependencies; the flag below
    # records that the real engine cannot be used.
    print(f"Warning: Higgs TTS dependencies not available: {import_error}")
    print("To use Higgs TTS, install the boson_multimodal package")
    HIGGS_AVAILABLE = False

    # Empty stand-ins so that names used in annotations elsewhere in this
    # module still resolve when the real classes could not be imported.
    class HiggsAudioServeEngine:
        pass

    class HiggsAudioResponse:
        pass

    class ChatMLSample:
        pass

    class Message:
        pass

    class AudioContent:
        pass
else:
    # Both imports succeeded, so the real Higgs classes are in scope.
    HIGGS_AVAILABLE = True
|
|
|
|
class HiggsTTSService(BaseTTSService):
    """Higgs TTS implementation with voice cloning.

    Wraps a ``HiggsAudioServeEngine``. A voice is cloned by sending the
    engine a three-message conversation: the reference transcript, the
    reference audio, and finally the text to synthesize.
    """

    def __init__(self, device: str = "auto",
                 model_path: Optional[str] = None,
                 audio_tokenizer_path: Optional[str] = None):
        """Configure the backend without loading the model.

        Args:
            device: Device selector forwarded to ``BaseTTSService.__init__``.
            model_path: Model name or path; falls back to ``HIGGS_MODEL_PATH``
                when omitted or empty.
            audio_tokenizer_path: Audio tokenizer name or path; falls back to
                ``HIGGS_AUDIO_TOKENIZER_PATH`` when omitted or empty.
        """
        super().__init__(device)
        self.backend_name = "higgs"
        self.model_path = model_path or HIGGS_MODEL_PATH
        self.audio_tokenizer_path = audio_tokenizer_path or HIGGS_AUDIO_TOKENIZER_PATH
        # The engine is created lazily by load_model().
        self.engine = None

        if not HIGGS_AVAILABLE:
            print("Warning: Higgs TTS backend created but dependencies not available")

    async def load_model(self) -> None:
        """Load the Higgs TTS model if it is not loaded yet.

        Raises:
            TTSError: If the Higgs dependencies are missing (code
                ``MISSING_DEPENDENCIES``) or the engine fails to construct.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError(
                "Higgs TTS dependencies not available. Install boson_multimodal package.",
                "higgs",
                "MISSING_DEPENDENCIES"
            )

        if self.engine is not None:
            return

        print(f"Loading Higgs TTS model to device: {self.device}...")
        try:
            self.engine = HiggsAudioServeEngine(
                model_name_or_path=self.model_path,
                audio_tokenizer_name_or_path=self.audio_tokenizer_path,
                device=self.device,
            )
        except Exception as e:
            # Chain the cause so the engine's own traceback is preserved.
            raise TTSError(f"Error loading Higgs TTS model: {e}", "higgs") from e

        self.model = self.engine  # Set model for is_loaded() check
        print("Higgs TTS model loaded successfully.")

    async def unload_model(self) -> None:
        """Drop the engine reference and run the base-class memory cleanup."""
        if self.engine is None:
            return

        print("Unloading Higgs TTS model...")
        self.engine = None
        self.model = None
        self._cleanup_memory()
        print("Higgs TTS model unloaded.")

    def validate_speaker_config(self, config: SpeakerConfig) -> bool:
        """Check that a speaker config is usable by the Higgs backend.

        Args:
            config: Speaker configuration to check.

        Returns:
            True only when the config targets the ``"higgs"`` backend, has a
            non-empty ``reference_text``, and its ``sample_path`` resolves to
            an existing regular file.
        """
        if config.tts_backend != "higgs":
            return False

        if not config.reference_text:
            return False

        # A missing path would make Path() raise, and an empty one would
        # resolve to the speaker data directory itself.
        if not config.sample_path:
            return False

        # is_file() rather than exists(): the sample is later opened for
        # reading, which a directory cannot satisfy.
        return Path(self._resolve_sample_path(config)).is_file()

    def _resolve_sample_path(self, config: SpeakerConfig) -> str:
        """Return the sample path, anchoring relative paths at SPEAKER_DATA_BASE_DIR."""
        sample_path = Path(config.sample_path)
        if not sample_path.is_absolute():
            sample_path = Path(SPEAKER_DATA_BASE_DIR) / sample_path
        return str(sample_path)

    def _encode_audio_to_base64(self, audio_path: str) -> str:
        """Read an audio file and return its bytes as a base64 string.

        Raises:
            BackendSpecificError: If the file cannot be read or encoded.
        """
        try:
            with open(audio_path, "rb") as audio_file:
                return base64.b64encode(audio_file.read()).decode("utf-8")
        except Exception as e:
            raise BackendSpecificError(
                f"Failed to encode audio file {audio_path}: {e}", "higgs"
            ) from e

    def _create_chatml_sample(self, request: TTSRequest) -> ChatMLSample:
        """Build the ChatML conversation used for Higgs voice cloning.

        The conversation is: user reference transcript, assistant reference
        audio (base64), then the user text to synthesize.

        Raises:
            TTSError: If the Higgs dependencies are missing.
            BackendSpecificError: If the reference audio cannot be encoded or
                the sample cannot be assembled.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError("Higgs TTS dependencies not available", "higgs", "MISSING_DEPENDENCIES")

        try:
            # Get absolute path to audio sample
            audio_path = self._resolve_sample_path(request.speaker_config)

            # Encode reference audio
            reference_audio_b64 = self._encode_audio_to_base64(audio_path)

            # Create conversation pattern for voice cloning
            messages = [
                Message(
                    role="user",
                    content=request.speaker_config.reference_text,
                ),
                Message(
                    role="assistant",
                    content=AudioContent(
                        raw_audio=reference_audio_b64,
                        audio_url="placeholder"
                    ),
                ),
                Message(
                    role="user",
                    content=request.text,
                ),
            ]

            return ChatMLSample(messages=messages)
        except TTSError:
            # Already a service error; pass it through unwrapped, as
            # generate_speech does.
            # NOTE(review): this only takes effect for the encode failure if
            # BackendSpecificError derives from TTSError - confirm.
            raise
        except Exception as e:
            raise BackendSpecificError(f"Error creating ChatML sample: {e}", "higgs") from e

    async def generate_speech(self, request: TTSRequest) -> TTSResponse:
        """Generate speech using Higgs TTS and save it to disk.

        The model is loaded on first use. Sampling options are read from
        ``request.parameters.backend_params`` with these defaults:
        ``max_new_tokens=1024``, ``top_p=0.95``, ``top_k=50``,
        ``stop_strings=["<|end_of_text|>", "<|eot_id|>"]``. The temperature
        comes from ``request.parameters.temperature``.

        Args:
            request: Text, speaker config, parameters and output config.

        Returns:
            A ``TTSResponse`` with the output path, generated text, audio
            duration in seconds, sampling rate and backend name.

        Raises:
            TTSError: If dependencies are missing, the speaker config is
                invalid, no audio is produced, or generation fails.
        """
        if not HIGGS_AVAILABLE:
            raise TTSError("Higgs TTS dependencies not available", "higgs", "MISSING_DEPENDENCIES")

        if self.engine is None:
            await self.load_model()

        # Validate speaker configuration
        if not self.validate_speaker_config(request.speaker_config):
            raise TTSError(
                f"Invalid speaker config for Higgs: {request.speaker_config.name}. "
                "Ensure reference_text is provided and audio sample exists.",
                "higgs"
            )

        # Extract Higgs-specific parameters; tolerate backend_params being unset.
        backend_params = request.parameters.backend_params or {}
        max_new_tokens = backend_params.get("max_new_tokens", 1024)
        temperature = request.parameters.temperature
        top_p = backend_params.get("top_p", 0.95)
        top_k = backend_params.get("top_k", 50)
        stop_strings = backend_params.get("stop_strings", ["<|end_of_text|>", "<|eot_id|>"])

        # Set up output path; Path() also accepts a directory given as a string.
        output_dir = Path(request.output_config.output_dir or TTS_TEMP_OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{request.output_config.filename_base}.{request.output_config.format}"

        print(f"Generating Higgs TTS audio for: \"{request.text[:50]}...\" with speaker: {request.speaker_config.name}")
        print(f"Using reference text: \"{request.speaker_config.reference_text[:30]}...\"")

        # Create ChatML sample and generate speech
        try:
            chat_sample = self._create_chatml_sample(request)

            # NOTE(review): this looks like a synchronous call inside a
            # coroutine, so it would block the event loop - confirm.
            response: HiggsAudioResponse = self.engine.generate(
                chat_ml_sample=chat_sample,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=stop_strings,
            )

            if response.audio is None:
                raise BackendSpecificError("No audio generated by Higgs TTS", "higgs")

            # Convert numpy audio to a tensor; 1-D audio gets a channel axis.
            audio_tensor = torch.from_numpy(response.audio)
            if audio_tensor.ndim == 1:
                audio_tensor = audio_tensor.unsqueeze(0)

            torchaudio.save(str(output_path), audio_tensor, response.sampling_rate)
            print(f"Higgs TTS audio saved to: {output_path}")

            # Duration is taken from the sample axis. len(response.audio)
            # would be the first axis, which for 2-D audio is not the sample
            # count. Assumes 2-D audio is (channels, samples), the layout
            # torchaudio.save expects by default.
            audio_duration = audio_tensor.shape[-1] / response.sampling_rate

            return TTSResponse(
                output_path=output_path,
                generated_text=response.generated_text,
                audio_duration=audio_duration,
                sampling_rate=response.sampling_rate,
                backend_used=self.backend_name
            )
        except TTSError:
            raise
        except Exception as e:
            raise TTSError(f"Error during Higgs TTS generation: {e}", "higgs") from e
        finally:
            # Runs on success and failure alike.
            self._cleanup_memory()

    def get_model_info(self) -> dict:
        """Return backend name, configured paths, device, and load/dependency status."""
        return {
            "backend": self.backend_name,
            "model_path": self.model_path,
            "audio_tokenizer_path": self.audio_tokenizer_path,
            "device": self.device,
            "loaded": self.is_loaded(),
            "dependencies_available": HIGGS_AVAILABLE
        }