"""
|
|
Simplified Higgs TTS Service
|
|
Direct integration with Higgs TTS for voice cloning
|
|
"""

import asyncio
import base64
import os
import uuid
from pathlib import Path
from typing import Optional

# Graceful import of Higgs TTS
try:
    from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine
    from boson_multimodal.data_types import ChatMLSample, AudioContent, Message

    HIGGS_AVAILABLE = True
    print("✅ Higgs TTS dependencies available")
except ImportError as e:
    HIGGS_AVAILABLE = False
    print(f"⚠️ Higgs TTS not available: {e}")
    print("To use Higgs TTS, install: pip install boson-multimodal")


class TTSService:
    """Simplified TTS Service using Higgs TTS"""

    def __init__(self, device: str = "auto"):
        self.device = self._resolve_device(device)
        self.model = None
        self.is_loaded = False

    def _resolve_device(self, device: str) -> str:
        """Resolve a device string, mapping "auto" to cuda, mps, or cpu."""
        if device == "auto":
            try:
                import torch
                if torch.cuda.is_available():
                    return "cuda"
                elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                    return "mps"
                else:
                    return "cpu"
            except ImportError:
                return "cpu"
        return device

    def load_model(self):
        """Load the Higgs TTS model"""
        if not HIGGS_AVAILABLE:
            raise RuntimeError("Higgs TTS dependencies not available. Install boson-multimodal package.")

        if self.is_loaded:
            return

        print(f"Loading Higgs TTS model on device: {self.device}")

        try:
            # Initialize Higgs serve engine
            self.model = HiggsAudioServeEngine(
                model_name_or_path="bosonai/higgs-audio-v2-generation-3B-base",
                audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer",
                device=self.device
            )
            self.is_loaded = True
            print("✅ Higgs TTS model loaded successfully")

        except Exception as e:
            print(f"❌ Failed to load Higgs TTS model: {e}")
            raise RuntimeError(f"Failed to load Higgs TTS model: {e}") from e

    def unload_model(self):
        """Unload the TTS model to free memory"""
        if self.model is not None:
            del self.model
            self.model = None
            self.is_loaded = False

        # Clear GPU cache if available
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except ImportError:
            pass

        print("✅ Higgs TTS model unloaded")

    def _audio_file_to_base64(self, audio_path: str) -> str:
        """Convert audio file to base64 string"""
        with open(audio_path, 'rb') as audio_file:
            return base64.b64encode(audio_file.read()).decode('utf-8')

    def _create_chatml_sample(self, text: str, reference_text: str, reference_audio_path: str,
                              description: Optional[str] = None) -> 'ChatMLSample':
        """Create ChatML sample for Higgs TTS voice cloning"""
        if not HIGGS_AVAILABLE:
            raise RuntimeError("ChatML dependencies not available")

        # Encode reference audio to base64
        audio_base64 = self._audio_file_to_base64(reference_audio_path)

        # Create system prompt with scene description (following Higgs pattern).
        # Use the provided description, or default to a natural style.
        speaker_style = description if description and description.strip() else "natural;clear voice;moderate pitch"
        scene_desc = f"<|scene_desc_start|>\nSPEAKER0: {speaker_style}\n<|scene_desc_end|>"
        system_prompt = f"Generate audio following instruction.\n\n{scene_desc}"

        # Create messages following the voice cloning pattern from Higgs examples:
        # system prompt, reference text (user), reference audio (assistant), target text (user)
        messages = [
            # System message with scene description
            Message(role="system", content=system_prompt),
            # User provides reference text
            Message(role="user", content=reference_text),
            # Assistant provides reference audio
            Message(
                role="assistant",
                content=AudioContent(
                    raw_audio=audio_base64,
                    audio_url="placeholder"
                )
            ),
            # User requests target text
            Message(role="user", content=text)
        ]

        # Create ChatML sample
        return ChatMLSample(messages=messages)

    async def generate_speech(
        self,
        text: str,
        speaker_sample_path: str,
        reference_text: str,
        output_filename_base: str,
        output_dir: Path,
        description: Optional[str] = None,
        temperature: float = 0.9,
        max_new_tokens: int = 1024,
        top_p: float = 0.95,
        top_k: int = 50,
        **kwargs
    ) -> Path:
        """
        Generate speech using Higgs TTS voice cloning

        Args:
            text: Text to synthesize
            speaker_sample_path: Path to speaker audio sample
            reference_text: Text corresponding to the audio sample
            output_filename_base: Base name for output file
            output_dir: Directory for output files
            description: Optional speaker style description injected into the scene prompt
            temperature: Sampling temperature
            max_new_tokens: Maximum tokens to generate
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling limit

        Returns:
            Path to generated audio file
        """
        if not HIGGS_AVAILABLE:
            raise RuntimeError("Higgs TTS not available. Install boson-multimodal package.")

        if not self.is_loaded:
            self.load_model()

        # Ensure output directory exists
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Create output filename
        output_filename = f"{output_filename_base}_{uuid.uuid4().hex[:8]}.wav"
        output_path = output_dir / output_filename

        try:
            print(f"Generating speech: '{text[:50]}...'")
            print(f"Using voice sample: {speaker_sample_path}")
            print(f"Reference text: '{reference_text[:50]}...'")

            # Validate audio file exists
            if not os.path.exists(speaker_sample_path):
                raise FileNotFoundError(f"Speaker audio file not found: {speaker_sample_path}")

            file_size = os.path.getsize(speaker_sample_path)
            if file_size == 0:
                raise ValueError(f"Speaker audio file is empty: {speaker_sample_path}")

            print(f"Audio file validated: {file_size} bytes")

            # Create ChatML sample for Higgs TTS
            chatml_sample = self._create_chatml_sample(text, reference_text, speaker_sample_path, description)

            # Run the blocking generation in a worker thread so the event loop stays responsive
            await asyncio.get_running_loop().run_in_executor(
                None,
                self._generate_sync,
                chatml_sample,
                str(output_path),
                temperature,
                max_new_tokens,
                top_p,
                top_k
            )

            if not output_path.exists():
                raise RuntimeError(f"Audio generation failed - output file not created: {output_path}")

            print(f"✅ Speech generated: {output_path}")
            return output_path

        except Exception as e:
            print(f"❌ Speech generation failed: {e}")
            raise RuntimeError(f"Failed to generate speech: {e}") from e

    def _generate_sync(self, chatml_sample: 'ChatMLSample', output_path: str, temperature: float,
                       max_new_tokens: int, top_p: float, top_k: int) -> None:
        """Synchronous generation wrapper for thread execution"""
        try:
            # Generate with Higgs TTS via the serve engine API
            response = self.model.generate(
                chat_ml_sample=chatml_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                top_k=top_k,
                force_audio_gen=True  # Ensure audio generation
            )

            # Save the generated audio
            if response.audio is not None:
                import torch
                import torchaudio

                # torchaudio.save expects a 2-D (channels, samples) tensor;
                # response.audio may arrive as a numpy array or a tensor
                audio_tensor = torch.as_tensor(response.audio)
                if audio_tensor.dim() == 1:
                    audio_tensor = audio_tensor.unsqueeze(0)

                sample_rate = response.sampling_rate or 24000
                torchaudio.save(output_path, audio_tensor, sample_rate)
            else:
                raise RuntimeError("No audio output generated")

        except Exception as e:
            raise RuntimeError(f"Higgs TTS generation failed: {e}") from e