chatterbox-ui/backend/app/services/tts_service.py

"""
Simplified Higgs TTS Service
Direct integration with Higgs TTS for voice cloning
"""
import asyncio
import os
import uuid
from pathlib import Path
from typing import Optional, Dict, Any
import base64
# Graceful import of Higgs TTS
try:
    from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine
    from boson_multimodal.data_types import ChatMLSample, AudioContent, Message
    HIGGS_AVAILABLE = True
    print("✅ Higgs TTS dependencies available")
except ImportError as e:
    HIGGS_AVAILABLE = False
    print(f"⚠️ Higgs TTS not available: {e}")
    print("To use Higgs TTS, install: pip install boson-multimodal")


class TTSService:
"""Simplified TTS Service using Higgs TTS"""
def __init__(self, device: str = "auto"):
self.device = self._resolve_device(device)
self.model = None
self.is_loaded = False
def _resolve_device(self, device: str) -> str:
"""Resolve device string to actual device"""
if device == "auto":
try:
import torch
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
except ImportError:
return "cpu"
return device

    def load_model(self):
        """Load the Higgs TTS model"""
        if not HIGGS_AVAILABLE:
            raise RuntimeError("Higgs TTS dependencies not available. Install boson-multimodal package.")
        if self.is_loaded:
            return

        print(f"Loading Higgs TTS model on device: {self.device}")
        try:
            # Initialize Higgs serve engine
            self.model = HiggsAudioServeEngine(
                model_name_or_path="bosonai/higgs-audio-v2-generation-3B-base",
                audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer",
                device=self.device
            )
            self.is_loaded = True
            print("✅ Higgs TTS model loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load Higgs TTS model: {e}")
            raise RuntimeError(f"Failed to load Higgs TTS model: {e}")

    def unload_model(self):
        """Unload the TTS model to free memory"""
        if self.model is not None:
            del self.model
            self.model = None
            self.is_loaded = False

        # Clear GPU cache if available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except ImportError:
            pass

        print("✅ Higgs TTS model unloaded")

    def _audio_file_to_base64(self, audio_path: str) -> str:
        """Convert audio file to base64 string"""
        with open(audio_path, 'rb') as audio_file:
            return base64.b64encode(audio_file.read()).decode('utf-8')

    def _create_chatml_sample(self, text: str, reference_text: str, reference_audio_path: str,
                              description: Optional[str] = None) -> 'ChatMLSample':
        """Create ChatML sample for Higgs TTS voice cloning"""
        if not HIGGS_AVAILABLE:
            raise RuntimeError("ChatML dependencies not available")

        # Encode reference audio to base64
        audio_base64 = self._audio_file_to_base64(reference_audio_path)

        # Create system prompt with scene description (following the Higgs pattern).
        # Use the provided description or default to a natural style.
        speaker_style = description if description and description.strip() else "natural;clear voice;moderate pitch"
        scene_desc = f"<|scene_desc_start|>\nSPEAKER0: {speaker_style}\n<|scene_desc_end|>"
        system_prompt = f"Generate audio following instruction.\n\n{scene_desc}"

        # Create messages following the voice cloning pattern from the Higgs examples
        messages = [
            # System message with scene description
            Message(role="system", content=system_prompt),
            # User provides reference text
            Message(role="user", content=reference_text),
            # Assistant provides reference audio
            Message(
                role="assistant",
                content=AudioContent(
                    raw_audio=audio_base64,
                    audio_url="placeholder"
                )
            ),
            # User requests target text
            Message(role="user", content=text)
        ]

        # Create ChatML sample
        return ChatMLSample(messages=messages)

    async def generate_speech(
        self,
        text: str,
        speaker_sample_path: str,
        reference_text: str,
        output_filename_base: str,
        output_dir: Path,
        description: Optional[str] = None,
        temperature: float = 0.9,
        max_new_tokens: int = 1024,
        top_p: float = 0.95,
        top_k: int = 50,
        **kwargs
    ) -> Path:
        """
        Generate speech using Higgs TTS voice cloning

        Args:
            text: Text to synthesize
            speaker_sample_path: Path to speaker audio sample
            reference_text: Text corresponding to the audio sample
            output_filename_base: Base name for output file
            output_dir: Directory for output files
            description: Optional speaker style for the scene description
            temperature: Sampling temperature
            max_new_tokens: Maximum tokens to generate
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling limit

        Returns:
            Path to generated audio file
        """
        if not HIGGS_AVAILABLE:
            raise RuntimeError("Higgs TTS not available. Install boson-multimodal package.")
        if not self.is_loaded:
            self.load_model()

        # Ensure output directory exists
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create output filename
        output_filename = f"{output_filename_base}_{uuid.uuid4().hex[:8]}.wav"
        output_path = output_dir / output_filename

        try:
            print(f"Generating speech: '{text[:50]}...'")
            print(f"Using voice sample: {speaker_sample_path}")
            print(f"Reference text: '{reference_text[:50]}...'")

            # Validate that the speaker audio file exists and is non-empty
            if not os.path.exists(speaker_sample_path):
                raise FileNotFoundError(f"Speaker audio file not found: {speaker_sample_path}")
            file_size = os.path.getsize(speaker_sample_path)
            if file_size == 0:
                raise ValueError(f"Speaker audio file is empty: {speaker_sample_path}")
            print(f"Audio file validated: {file_size} bytes")

            # Create ChatML sample for Higgs TTS
            chatml_sample = self._create_chatml_sample(text, reference_text, speaker_sample_path, description)

            # Generate audio in a worker thread so the event loop is not blocked
            await asyncio.get_event_loop().run_in_executor(
                None,
                self._generate_sync,
                chatml_sample,
                str(output_path),
                temperature,
                max_new_tokens,
                top_p,
                top_k
            )

            if not output_path.exists():
                raise RuntimeError(f"Audio generation failed - output file not created: {output_path}")

            print(f"✅ Speech generated: {output_path}")
            return output_path
        except Exception as e:
            print(f"❌ Speech generation failed: {e}")
            raise RuntimeError(f"Failed to generate speech: {e}")

    def _generate_sync(self, chatml_sample: 'ChatMLSample', output_path: str, temperature: float,
                       max_new_tokens: int, top_p: float, top_k: int) -> None:
        """Synchronous generation wrapper for thread execution"""
        try:
            # Generate with Higgs TTS via the serve engine
            response = self.model.generate(
                chat_ml_sample=chatml_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                force_audio_gen=True  # Ensure audio generation
            )

            # Save the generated audio
            if response.audio is not None:
                import torchaudio
                import torch

                # Convert a numpy array to a torch tensor with a channel dimension if needed
                if not isinstance(response.audio, torch.Tensor):
                    audio_tensor = torch.from_numpy(response.audio).unsqueeze(0)
                else:
                    audio_tensor = response.audio

                sample_rate = response.sampling_rate or 24000
                torchaudio.save(output_path, audio_tensor, sample_rate)
            else:
                raise RuntimeError("No audio output generated")
        except Exception as e:
            raise RuntimeError(f"Higgs TTS generation failed: {e}")