# Higgs TTS Integration Implementation Plan

## Overview

This document outlines the plan for refactoring the chatterbox-ui backend to support the Higgs TTS system alongside the existing ChatterboxTTS system. The plan incorporates code review recommendations and addresses the key architectural challenges they identified.

## Key Differences Between TTS Systems

### ChatterboxTTS

- Uses `ChatterboxTTS.from_pretrained()` and the `.generate()` method
- Voice cloning takes a plain audio prompt path
- Parameters: `exaggeration`, `cfg_weight`, `temperature`
- Returns a torch tensor at the model's sample rate (`model.sr`)

### Higgs TTS

- Uses `HiggsAudioServeEngine` with separate model and audio tokenizer paths
- Voice cloning requires base64-encoded reference audio plus its transcript (reference text), supplied in ChatML format
- Parameters: `max_new_tokens`, `temperature`, `top_p`, `top_k`
- Returns a `HiggsAudioResponse` whose `audio` field is a NumPy array (alongside `sampling_rate` and `generated_text`)
- Conversation-style interface built from alternating user/assistant messages

## Implementation Plan

### Phase 1: Foundation and Abstraction Layer

#### 1.1 Create Abstract Base Classes and Data Models

**File: `backend/app/models/tts_models.py`**

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class TTSParameters:
    """Common TTS parameters with backend-specific extensions"""
    temperature: float = 0.8
    backend_params: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SpeakerConfig:
    """Enhanced speaker configuration"""
    id: str
    name: str
    sample_path: str
    reference_text: Optional[str] = None
    tts_backend: str = "chatterbox"

    def validate(self) -> None:
        """Validate speaker configuration based on backend"""
        if self.tts_backend == "higgs" and not self.reference_text:
            raise ValueError(f"reference_text required for Higgs backend speaker: {self.name}")


@dataclass
class OutputConfig:
    """Output configuration for TTS generation"""
    filename_base: str
    output_dir: Optional[Path] = None
    format: str = "wav"


@dataclass
class TTSRequest:
    """Unified TTS request structure"""
    text: str
    speaker_config: SpeakerConfig
    parameters: TTSParameters
    output_config: OutputConfig


@dataclass
class TTSResponse:
    """Unified TTS response structure"""
    output_path: Path
    generated_text: Optional[str] = None
    audio_duration: Optional[float] = None
    sampling_rate: Optional[int] = None
    backend_used: str = ""
```

**File: `backend/app/services/base_tts_service.py`**

```python
import gc
from abc import ABC, abstractmethod
from typing import Optional

import torch

# SpeakerConfig is defined in tts_models alongside the request/response types
from ..models.tts_models import SpeakerConfig, TTSRequest, TTSResponse


class TTSError(Exception):
    """Base exception for TTS operations"""

    def __init__(self, message: str, backend: str, error_code: Optional[str] = None):
        super().__init__(message)
        self.backend = backend
        self.error_code = error_code


class BackendSpecificError(TTSError):
    """Backend-specific TTS errors"""
    pass


class BaseTTSService(ABC):
    """Abstract base class for TTS services"""

    def __init__(self, device: str = "auto"):
        self.device = self._resolve_device(device)
        self.model = None
        # e.g. ChatterboxTTSService -> "chatterbox"
        self.backend_name = self.__class__.__name__.replace("TTSService", "").lower()

    def _resolve_device(self, device: str) -> str:
        """Resolve "auto" to the best available device"""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            if torch.backends.mps.is_available():
                return "mps"
            return "cpu"
        return device

    @abstractmethod
    async def load_model(self) -> None:
        """Load the TTS model"""

    @abstractmethod
    async def unload_model(self) -> None:
        """Unload the TTS model and free memory"""

    @abstractmethod
    async def generate_speech(self, request: TTSRequest) -> TTSResponse:
        """Generate speech from TTS request"""

    @abstractmethod
    def validate_speaker_config(self, config: SpeakerConfig) -> bool:
        """Validate speaker configuration for this backend"""

    def _cleanup_memory(self) -> None:
        """Common memory cleanup routine"""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()
        elif self.device == "mps":
            # torch.mps.empty_cache only exists in newer PyTorch releases
            if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()
```

#### 1.2 Configuration System Updates

**File: `backend/app/config.py` (additions)**

```python
import os

# Higgs TTS Configuration
HIGGS_MODEL_PATH = os.getenv("HIGGS_MODEL_PATH", "bosonai/higgs-audio-v2-generation-3B-base")
HIGGS_AUDIO_TOKENIZER_PATH = os.getenv("HIGGS_AUDIO_TOKENIZER_PATH", "bosonai/higgs-audio-v2-tokenizer")
DEFAULT_TTS_BACKEND = os.getenv("DEFAULT_TTS_BACKEND", "chatterbox")

# Backend-specific parameter defaults
TTS_BACKEND_DEFAULTS = {
    "chatterbox": {
        "exaggeration": 0.5,
        "cfg_weight": 0.5,
        "temperature": 0.8,
    },
    "higgs": {
        "max_new_tokens": 1024,
        "temperature": 0.9,
        "top_p": 0.95,
        "top_k": 50,
        "stop_strings": ["<|end_of_text|>", "<|eot_id|>"],
    },
}
```

### Phase 2: Service Implementation

#### 2.1 Refactor Existing ChatterboxTTS Service

**File: `backend/app/services/chatterbox_tts_service.py`**

```python
from pathlib import Path

import torch
import torchaudio

from .base_tts_service import BaseTTSService, TTSError
from ..models.tts_models import SpeakerConfig, TTSRequest, TTSResponse
from ..config import TTS_TEMP_OUTPUT_DIR

# Import existing chatterbox functionality
from chatterbox.tts import ChatterboxTTS


class ChatterboxTTSService(BaseTTSService):
    """Chatterbox TTS implementation"""

    def __init__(self, device: str = "auto"):
        super().__init__(device)
        self.backend_name = "chatterbox"

    async def load_model(self) -> None:
        """Load ChatterboxTTS model with device mapping"""
        if self.model is None:
            print(f"Loading ChatterboxTTS model to device: {self.device}...")
            try:
                self.model = self._safe_load_chatterbox_tts(self.device)
                print("ChatterboxTTS model loaded successfully.")
            except Exception as e:
                raise TTSError(f"Error loading ChatterboxTTS model: {e}", "chatterbox") from e

    async def unload_model(self) -> None:
        """Unload ChatterboxTTS model"""
        if self.model is not None:
            print("Unloading ChatterboxTTS model...")
            del self.model
            self.model = None
            self._cleanup_memory()
            print("ChatterboxTTS model unloaded.")

    def validate_speaker_config(self, config: SpeakerConfig) -> bool:
        """Validate speaker config for Chatterbox backend"""
        return config.tts_backend == "chatterbox" and Path(config.sample_path).exists()

    async def generate_speech(self, request: TTSRequest) -> TTSResponse:
        """Generate speech using ChatterboxTTS"""
        if self.model is None:
            await self.load_model()

        # Validate speaker configuration
        if not self.validate_speaker_config(request.speaker_config):
            raise TTSError(
                f"Invalid speaker config for Chatterbox: {request.speaker_config.name}",
                "chatterbox",
            )

        # Extract Chatterbox-specific parameters
        backend_params = request.parameters.backend_params
        exaggeration = backend_params.get("exaggeration", 0.5)
        cfg_weight = backend_params.get("cfg_weight", 0.5)
        temperature = request.parameters.temperature

        # Set up output path (wrap in Path in case the config value is a str)
        output_dir = Path(request.output_config.output_dir or TTS_TEMP_OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{request.output_config.filename_base}.wav"

        wav = None
        try:
            with torch.no_grad():
                wav = self.model.generate(
                    text=request.text,
                    audio_prompt_path=request.speaker_config.sample_path,
                    exaggeration=exaggeration,
                    cfg_weight=cfg_weight,
                    temperature=temperature,
                )
            torchaudio.save(str(output_path), wav, self.model.sr)

            # wav is shaped (channels, samples): duration = samples / sample rate
            audio_duration = wav.shape[1] / self.model.sr if wav is not None else None

            return TTSResponse(
                output_path=output_path,
                audio_duration=audio_duration,
                sampling_rate=self.model.sr,
                backend_used=self.backend_name,
            )
        except Exception as e:
            raise TTSError(f"Error during Chatterbox TTS generation: {e}", "chatterbox") from e
        finally:
            # Drop the tensor reference before releasing cached device memory
            wav = None
            self._cleanup_memory()

    def _safe_load_chatterbox_tts(self, device):
        """Safe loading with device mapping (existing implementation)"""
        # ... existing implementation from current tts_service.py
        pass
```

#### 2.2 Create Higgs TTS Service

**File: `backend/app/services/higgs_tts_service.py`**

```python
import base64
from pathlib import Path
from typing import Optional

import torch
import torchaudio

from .base_tts_service import BackendSpecificError, BaseTTSService, TTSError
from ..models.tts_models import SpeakerConfig, TTSRequest, TTSResponse
from ..config import HIGGS_AUDIO_TOKENIZER_PATH, HIGGS_MODEL_PATH, TTS_TEMP_OUTPUT_DIR

# Higgs imports
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
from boson_multimodal.data_types import AudioContent, ChatMLSample, Message


class HiggsTTSService(BaseTTSService):
    """Higgs TTS implementation"""

    def __init__(self, device: str = "auto", model_path: Optional[str] = None,
                 audio_tokenizer_path: Optional[str] = None):
        super().__init__(device)
        self.backend_name = "higgs"
        self.model_path = model_path or HIGGS_MODEL_PATH
        self.audio_tokenizer_path = audio_tokenizer_path or HIGGS_AUDIO_TOKENIZER_PATH
        # Higgs wraps model and tokenizer in a serve engine instead of using self.model
        self.engine = None

    async def load_model(self) -> None:
        """Load Higgs TTS model"""
        if self.engine is None:
            print(f"Loading Higgs TTS model to device: {self.device}...")
            try:
                self.engine = HiggsAudioServeEngine(
                    model_name_or_path=self.model_path,
                    audio_tokenizer_name_or_path=self.audio_tokenizer_path,
                    device=self.device,
                )
                print("Higgs TTS model loaded successfully.")
            except Exception as e:
                raise TTSError(f"Error loading Higgs TTS model: {e}", "higgs") from e

    async def unload_model(self) -> None:
        """Unload Higgs TTS model"""
        if self.engine is not None:
            print("Unloading Higgs TTS model...")
            del self.engine
            self.engine = None
            self._cleanup_memory()
            print("Higgs TTS model unloaded.")

    def validate_speaker_config(self, config: SpeakerConfig) -> bool:
        """Validate speaker config for Higgs backend (reference_text is mandatory)"""
        return (
            config.tts_backend == "higgs"
            and bool(config.reference_text)
            and Path(config.sample_path).exists()
        )

    def _encode_audio_to_base64(self, audio_path: str) -> str:
        """Encode audio file to base64 string"""
        try:
            with open(audio_path, "rb") as audio_file:
                return base64.b64encode(audio_file.read()).decode("utf-8")
        except Exception as e:
            raise BackendSpecificError(f"Failed to encode audio file {audio_path}: {e}", "higgs") from e

    def _create_chatml_sample(self, request: TTSRequest) -> ChatMLSample:
        """Create ChatML sample for Higgs voice cloning"""
        try:
            reference_audio_b64 = self._encode_audio_to_base64(request.speaker_config.sample_path)

            # Voice-cloning pattern: user turn = reference transcript, assistant turn =
            # matching reference audio, final user turn = text to speak in that voice
            messages = [
                Message(role="user", content=request.speaker_config.reference_text),
                Message(
                    role="assistant",
                    content=AudioContent(raw_audio=reference_audio_b64, audio_url="placeholder"),
                ),
                Message(role="user", content=request.text),
            ]
            return ChatMLSample(messages=messages)
        except BackendSpecificError:
            raise  # already carries a precise message; avoid double-wrapping
        except Exception as e:
            raise BackendSpecificError(f"Error creating ChatML sample: {e}", "higgs") from e

    async def generate_speech(self, request: TTSRequest) -> TTSResponse:
        """Generate speech using Higgs TTS"""
        if self.engine is None:
            await self.load_model()

        # Validate speaker configuration
        if not self.validate_speaker_config(request.speaker_config):
            raise TTSError(
                f"Invalid speaker config for Higgs: {request.speaker_config.name}",
                "higgs",
            )

        # Extract Higgs-specific parameters
        backend_params = request.parameters.backend_params
        max_new_tokens = backend_params.get("max_new_tokens", 1024)
        temperature = request.parameters.temperature
        top_p = backend_params.get("top_p", 0.95)
        top_k = backend_params.get("top_k", 50)
        stop_strings = backend_params.get("stop_strings", ["<|end_of_text|>", "<|eot_id|>"])

        # Set up output path
        output_dir = Path(request.output_config.output_dir or TTS_TEMP_OUTPUT_DIR)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{request.output_config.filename_base}.wav"

        # Create ChatML sample and generate speech
        try:
            chat_sample = self._create_chatml_sample(request)
            response: HiggsAudioResponse = self.engine.generate(
                chat_ml_sample=chat_sample,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=stop_strings,
            )

            if response.audio is None:
                raise BackendSpecificError("No audio generated by Higgs TTS", "higgs")

            # response.audio is 1-D; torchaudio.save expects (channels, samples)
            audio_tensor = torch.from_numpy(response.audio).unsqueeze(0)
            torchaudio.save(str(output_path), audio_tensor, response.sampling_rate)
            audio_duration = len(response.audio) / response.sampling_rate

            return TTSResponse(
                output_path=output_path,
                generated_text=response.generated_text,
                audio_duration=audio_duration,
                sampling_rate=response.sampling_rate,
                backend_used=self.backend_name,
            )
        except TTSError:
            raise
        except Exception as e:
            raise TTSError(f"Error during Higgs TTS generation: {e}", "higgs") from e
        finally:
            self._cleanup_memory()
```

#### 2.3 Create TTS Service Factory

**File: `backend/app/services/tts_factory.py`**

```python
from typing import Dict, List, Optional, Type

from .base_tts_service import BaseTTSService, TTSError
from .chatterbox_tts_service import ChatterboxTTSService
from .higgs_tts_service import HiggsTTSService
from ..config import DEFAULT_TTS_BACKEND


class TTSServiceFactory:
    """Factory for creating TTS service instances"""

    _services: Dict[str, Type[BaseTTSService]] = {
        "chatterbox": ChatterboxTTSService,
        "higgs": HiggsTTSService,
    }
    _instances: Dict[str, BaseTTSService] = {}

    @classmethod
    def register_service(cls, name: str, service_class: Type[BaseTTSService]) -> None:
        """Register a new TTS service type"""
        cls._services[name] = service_class

    @classmethod
    def create_service(cls, backend: Optional[str] = None, device: str = "auto",
                       singleton: bool = True) -> BaseTTSService:
        """Create or retrieve a TTS service instance.

        When singleton=True and an instance already exists, it is returned
        unchanged and the `device` argument is ignored.
        """
        backend = backend or DEFAULT_TTS_BACKEND
        if backend not in cls._services:
            available = ", ".join(cls._services.keys())
            raise TTSError(f"Unknown TTS backend: {backend}. Available: {available}", backend)

        # Return the cached singleton if one exists
        if singleton and backend in cls._instances:
            return cls._instances[backend]

        # Create new instance
        instance = cls._services[backend](device=device)
        if singleton:
            cls._instances[backend] = instance
        return instance

    @classmethod
    def get_available_backends(cls) -> List[str]:
        """Get list of available TTS backends"""
        return list(cls._services.keys())

    @classmethod
    async def cleanup_all(cls) -> None:
        """Unload every cached service instance"""
        for service in cls._instances.values():
            try:
                await service.unload_model()
            except Exception as e:
                print(f"Error unloading {service.backend_name}: {e}")
        cls._instances.clear()
```

### Phase 3: Enhanced Data Models and Validation

#### 3.1 Update Speaker Model

**File: `backend/app/models/speaker_models.py` (updates)**

```python
# Pydantic v2 API (the original `from_attributes` setting is a v2 option)
from typing import Optional

from pydantic import BaseModel, ConfigDict, field_validator, model_validator

VALID_TTS_BACKENDS = ["chatterbox", "higgs"]


class SpeakerBase(BaseModel):
    name: str
    reference_text: Optional[str] = None
    tts_backend: str = "chatterbox"

    @field_validator("tts_backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        """Validate TTS backend selection"""
        if v not in VALID_TTS_BACKENDS:
            raise ValueError(f"Invalid TTS backend: {v}. Must be one of {VALID_TTS_BACKENDS}")
        return v

    @model_validator(mode="after")
    def validate_reference_text_for_higgs(self) -> "SpeakerBase":
        """Higgs speakers need reference text. A model-level check runs even when
        reference_text is omitted and sees every field regardless of declaration order."""
        if self.tts_backend == "higgs" and not self.reference_text:
            raise ValueError("reference_text is required for Higgs TTS backend")
        return self


class SpeakerCreate(SpeakerBase):
    """Model for speaker creation requests"""
    pass


class Speaker(SpeakerBase):
    """Complete speaker model with ID and sample path"""
    model_config = ConfigDict(from_attributes=True)

    id: str
    sample_path: Optional[str] = None
```

#### 3.2 Update Speaker Service

**File: `backend/app/services/speaker_service.py` (key updates)**

```python
# Add to SpeakerManagementService class
# (uses the module's existing imports: uuid, UploadFile, HTTPException, config, Speaker)
from typing import Optional


async def add_speaker(self, name: str, audio_file: UploadFile,
                      reference_text: Optional[str] = None,
                      tts_backend: str = "chatterbox") -> Speaker:
    """Enhanced speaker creation with TTS backend support"""
    speaker_id = str(uuid.uuid4())

    # Validate backend-specific requirements
    if tts_backend == "higgs" and not reference_text:
        raise HTTPException(
            status_code=400,
            detail="reference_text is required for Higgs TTS backend",
        )

    # ... existing audio processing code ...

    new_speaker_data = {
        "name": name,
        "sample_path": str(sample_path.relative_to(config.SPEAKER_DATA_BASE_DIR)),
        "reference_text": reference_text,
        "tts_backend": tts_backend,
    }

    self.speakers_data[speaker_id] = new_speaker_data
    self._save_speakers_data()

    return Speaker(id=speaker_id, **new_speaker_data)


def migrate_existing_speakers(self):
    """Migration utility: add the new fields to speakers saved before this change"""
    updated = False
    for speaker_data in self.speakers_data.values():
        if "tts_backend" not in speaker_data:
            speaker_data["tts_backend"] = "chatterbox"
            updated = True
        if "reference_text" not in speaker_data:
            speaker_data["reference_text"] = None
            updated = True

    if updated:
        self._save_speakers_data()
        print("Migrated existing speakers to new format")
```

### Phase 4: Service Integration

#### 4.1 Update Dialog Processor Service

**File: `backend/app/services/dialog_processor_service.py` (key updates)**

```python
from pathlib import Path

from .tts_factory import TTSServiceFactory
from ..models.tts_models import OutputConfig, SpeakerConfig, TTSParameters, TTSRequest
from ..config import TTS_BACKEND_DEFAULTS


class DialogProcessorService:
    def __init__(self):
        # No direct TTS service instantiation:
        # services are created via the factory as needed
        pass

    async def process_dialog_item(self, dialog_item, speaker_info, output_dir, segment_index):
        """Process individual dialog item with backend selection"""
        # Determine TTS backend from speaker info
        tts_backend = speaker_info.get("tts_backend", "chatterbox")

        # Get appropriate TTS service
        tts_service = TTSServiceFactory.create_service(tts_backend)

        # Copy the defaults so per-request changes never mutate the shared config dict;
        # temperature is a common parameter, so it is moved out of backend_params
        backend_params = dict(TTS_BACKEND_DEFAULTS.get(tts_backend, {}))
        temperature = backend_params.pop("temperature", 0.8)
        parameters = TTSParameters(temperature=temperature, backend_params=backend_params)

        # Create speaker config
        speaker_config = SpeakerConfig(
            id=speaker_info["id"],
            name=speaker_info["name"],
            sample_path=speaker_info["sample_path"],
            reference_text=speaker_info.get("reference_text"),
            tts_backend=tts_backend,
        )

        # Create TTS request
        request = TTSRequest(
            text=dialog_item["text"],
            speaker_config=speaker_config,
            parameters=parameters,
            output_config=OutputConfig(
                filename_base=f"dialog_line_{segment_index}_spk_{speaker_info['id']}",
                output_dir=Path(output_dir),
            ),
        )

        # Generate speech
        try:
            response = await tts_service.generate_speech(request)
            return response.output_path
        except Exception as e:
            print(f"Error generating speech with {tts_backend}: {e}")
            raise
```

### Phase 5: API and Frontend Updates

#### 5.1 Update API Endpoints

**File: `backend/app/api/endpoints/speakers.py` (additions)**

```python
from typing import Optional

from fastapi import File, Form, UploadFile

# `router`, `Speaker` and `SpeakerManagementService` come from the module's existing imports


@router.post("/", response_model=Speaker)
async def create_speaker(
    name: str = Form(...),
    tts_backend: str = Form("chatterbox"),
    reference_text: Optional[str] = Form(None),
    audio_file: UploadFile = File(...),
):
    """Enhanced speaker creation with TTS backend selection"""
    speaker_service = SpeakerManagementService()
    return await speaker_service.add_speaker(
        name=name,
        audio_file=audio_file,
        reference_text=reference_text,
        tts_backend=tts_backend,
    )


# FastAPI matches routes in declaration order: if this router also defines
# GET /{speaker_id}, register /backends before it so "backends" is not read as an ID
@router.get("/backends")
async def get_available_backends():
    """Get available TTS backends"""
    from app.services.tts_factory import TTSServiceFactory
    return {"backends": TTSServiceFactory.get_available_backends()}
```

#### 5.2 Frontend Updates

**File: `frontend/api.js` (additions)**

```javascript
// Add TTS backend support to speaker creation
async function createSpeaker(name, audioFile, ttsBackend = 'chatterbox', referenceText = null) {
    const formData = new FormData();
    formData.append('name', name);
    formData.append('audio_file', audioFile);
    formData.append('tts_backend', ttsBackend);
    // reference_text is only required by the Higgs backend
    if (referenceText) {
        formData.append('reference_text', referenceText);
    }

    const response = await fetch(`${API_BASE}/speakers/`, {
        method: 'POST',
        body: formData
    });

    if (!response.ok) {
        throw new Error(`Failed to create speaker: ${response.statusText}`);
    }
    return response.json();
}

async function getAvailableBackends() {
    const response = await fetch(`${API_BASE}/speakers/backends`);
    if (!response.ok) {
        throw new Error(`Failed to fetch TTS backends: ${response.statusText}`);
    }
    return response.json();
}
```

### Phase 6: Migration and Configuration

#### 6.1 Data Migration Script

**File: `backend/migrations/migrate_speakers.py`**

```python
#!/usr/bin/env python3
"""Migration script for existing speakers to new format"""
import sys
from pathlib import Path

# Put backend/ on sys.path so `app` resolves as a top-level package,
# matching the `app.services...` imports used by the API layer
backend_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(backend_dir))

from app.services.speaker_service import SpeakerManagementService  # noqa: E402


def migrate_speakers():
    """Migrate existing speakers to new format"""
    print("Starting speaker migration...")

    service = SpeakerManagementService()
    service.migrate_existing_speakers()

    print("Migration completed successfully!")

    # Show current speakers
    speakers = service.get_speakers()
    print(f"\nMigrated {len(speakers)} speakers:")
    for speaker in speakers:
        print(f"  - {speaker.name}: {speaker.tts_backend} backend")


if __name__ == "__main__":
    migrate_speakers()
```

#### 6.2 Environment Configuration Template

**File: `.env.template` (additions)**

```bash
# Higgs TTS Configuration
HIGGS_MODEL_PATH=bosonai/higgs-audio-v2-generation-3B-base
HIGGS_AUDIO_TOKENIZER_PATH=bosonai/higgs-audio-v2-tokenizer
DEFAULT_TTS_BACKEND=chatterbox

# Device Configuration (one of: auto, cpu, cuda, mps)
TTS_DEVICE=auto
```

## Implementation Priority

### High Priority (Phases 1-2)

1. ✅ **Abstract base class and data models** - Foundation
2. ✅ **Configuration system updates** - Environment management
3. ✅ **Chatterbox service refactoring** - Maintain existing functionality
4. ✅ **Higgs service implementation** - Core new functionality
5. ✅ **TTS factory pattern** - Service orchestration

### Medium Priority (Phases 3-4)

1. ✅ **Enhanced speaker models** - Data validation and backend support
2. ✅ **Speaker service updates** - CRUD operations with new fields
3. ✅ **Dialog processor integration** - Multi-backend dialog support
4. ⏳ **Error handling framework** - Comprehensive error management

### Lower Priority (Phases 5-6)

1. ⏳ **API endpoint updates** - REST API enhancements
2. ⏳ **Frontend integration** - UI updates for backend selection
3. ⏳ **Migration utilities** - Data migration and cleanup tools
4. ⏳ **Documentation updates** - User guides and API documentation

## Testing Strategy

### Unit Tests

- Test each TTS service independently
- Validate parameter mapping and conversion
- Test error handling scenarios
- Mock external dependencies (models, file I/O)

### Integration Tests

- Test factory pattern service creation
- Test dialog generation with mixed backends
- Test speaker management with different backends
- Test API endpoints with various request formats

### Performance Tests

- Compare memory usage between backends
- Benchmark generation speed
- Stress test with multiple concurrent requests
- Monitor device utilization (CPU/GPU/MPS)

## Deployment Considerations

### Environment Setup

1. Install the Higgs TTS dependencies in the existing environment
2. Download the required Higgs models, or point `HIGGS_MODEL_PATH` and `HIGGS_AUDIO_TOKENIZER_PATH` at local copies
3. Update environment variables for backend selection
4. Run the migration script on existing speaker data

### Backward Compatibility

- Existing speakers default to the chatterbox backend
- Existing API endpoints remain functional
- The frontend gracefully handles missing backend fields
- Configuration defaults preserve current behavior

### Performance Monitoring

- Track memory usage per backend
- Monitor generation times and success rates
- Log backend selection and usage statistics
- Alert on model loading failures

## Conclusion

This implementation plan adds support for multiple TTS backends while maintaining backward compatibility. Combining an abstract base class with a factory keeps backend-specific logic separate: adding another backend later requires a new `BaseTTSService` subclass registered through `TTSServiceFactory.register_service()`, plus an entry in the speaker model's list of valid backends.
Key success factors:

- Proper parameter abstraction using dedicated data classes
- Comprehensive validation for backend-specific requirements
- Robust error handling with backend-specific error types
- Thorough testing at unit, integration, and performance levels
- A careful migration strategy that preserves existing data and functionality

The plan addresses all critical code review recommendations and provides a solid foundation for the Higgs TTS integration.
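The Testing Strategy above calls for testing each service independently with mocked models. A minimal sketch of what that could look like for the factory, under stated assumptions: it uses only the standard library's `unittest` and `asyncio`, a hypothetical `FakeTTSService` that records load/unload calls instead of loading a model, and a trimmed, stand-alone copy of the `TTSServiceFactory` logic from section 2.3 so it runs without torch, chatterbox or boson_multimodal installed. A real suite would target the actual `TTSServiceFactory`, with the model packages stubbed out.

```python
import asyncio
import unittest
from typing import Dict, List, Optional, Type


class TTSError(Exception):
    """Trimmed copy of the plan's TTSError."""

    def __init__(self, message: str, backend: str, error_code: Optional[str] = None):
        super().__init__(message)
        self.backend = backend
        self.error_code = error_code


class FakeTTSService:
    """Hypothetical stand-in backend: flips a flag instead of loading a model."""

    def __init__(self, device: str = "auto"):
        self.device = device
        self.backend_name = "fake"
        self.loaded = False

    async def load_model(self) -> None:
        self.loaded = True

    async def unload_model(self) -> None:
        self.loaded = False


class TTSServiceFactory:
    """Trimmed copy of the section 2.3 factory (registry filled in by tests)."""

    _services: Dict[str, Type] = {}
    _instances: Dict[str, object] = {}

    @classmethod
    def register_service(cls, name: str, service_class: Type) -> None:
        cls._services[name] = service_class

    @classmethod
    def create_service(cls, backend: str, device: str = "auto", singleton: bool = True):
        if backend not in cls._services:
            available = ", ".join(cls._services)
            raise TTSError(f"Unknown TTS backend: {backend}. Available: {available}", backend)
        if singleton and backend in cls._instances:
            return cls._instances[backend]
        instance = cls._services[backend](device=device)
        if singleton:
            cls._instances[backend] = instance
        return instance

    @classmethod
    def get_available_backends(cls) -> List[str]:
        return list(cls._services.keys())

    @classmethod
    async def cleanup_all(cls) -> None:
        for service in cls._instances.values():
            try:
                await service.unload_model()
            except Exception as e:
                print(f"Error unloading {service.backend_name}: {e}")
        cls._instances.clear()


class FactoryTests(unittest.TestCase):
    def setUp(self) -> None:
        # Reset class-level state so singletons do not leak between tests
        TTSServiceFactory._services = {}
        TTSServiceFactory._instances = {}
        TTSServiceFactory.register_service("fake", FakeTTSService)

    def test_singleton_is_reused(self) -> None:
        first = TTSServiceFactory.create_service("fake")
        second = TTSServiceFactory.create_service("fake", device="cpu")
        self.assertIs(first, second)
        # The cached instance keeps the device it was created with
        self.assertEqual(second.device, "auto")

    def test_non_singleton_creates_new_instance(self) -> None:
        a = TTSServiceFactory.create_service("fake", singleton=False)
        b = TTSServiceFactory.create_service("fake", singleton=False)
        self.assertIsNot(a, b)

    def test_unknown_backend_raises(self) -> None:
        with self.assertRaises(TTSError) as ctx:
            TTSServiceFactory.create_service("missing")
        self.assertEqual(ctx.exception.backend, "missing")

    def test_cleanup_unloads_and_clears(self) -> None:
        service = TTSServiceFactory.create_service("fake")
        asyncio.run(service.load_model())
        asyncio.run(TTSServiceFactory.cleanup_all())
        self.assertFalse(service.loaded)
        self.assertEqual(TTSServiceFactory._instances, {})
```

Run it with `python -m unittest <module_name>`. Resetting `_services` and `_instances` in `setUp` matters because the factory keeps its registry and singletons as class-level state; the first test also pins down the documented behavior that a cached singleton ignores a later `device` argument.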