Singleton manages memory well and fast.
This commit is contained in:
parent
2af705ca43
commit
3548485b4e
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,9 @@
|
||||||
|
|
||||||
|
# 2025-06-14 18:21:08.215816
|
||||||
|
+yes
|
||||||
|
|
||||||
|
# 2025-06-14 18:21:29.450580
|
||||||
|
+/model
|
||||||
|
|
||||||
|
# 2025-06-14 18:22:01.292648
|
||||||
|
+/exit
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,36 @@
|
||||||
|
# OpenCode.md
|
||||||
|
|
||||||
|
## Build/Test Commands
|
||||||
|
```bash
|
||||||
|
# Backend setup and run (from project root)
|
||||||
|
pip install -r backend/requirements.txt
|
||||||
|
uvicorn backend.app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
|
||||||
|
# Frontend tests
|
||||||
|
npm test # Run all Jest tests
|
||||||
|
npm test -- --testNamePattern="getSpeakers" # Run single test
|
||||||
|
|
||||||
|
# Backend API test
|
||||||
|
python backend/run_api_test.py
|
||||||
|
|
||||||
|
# Alternative interface
|
||||||
|
python gradio_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Style Guidelines
|
||||||
|
|
||||||
|
### Python (Backend)
|
||||||
|
- **Imports**: Standard library first, third-party, then local imports with blank lines between groups
|
||||||
|
- **Types**: Use type hints extensively (`List[Speaker]`, `Optional[str]`, `Dict[str, Any]`)
|
||||||
|
- **Classes**: PascalCase (`SpeakerManagementService`, `DialogRequest`)
|
||||||
|
- **Functions/Variables**: snake_case (`get_speakers`, `speaker_id`, `audio_url`)
|
||||||
|
- **Error Handling**: Use FastAPI `HTTPException` with descriptive messages
|
||||||
|
- **Models**: Pydantic models with Field descriptions and validators
|
||||||
|
|
||||||
|
### JavaScript (Frontend)
|
||||||
|
- **Modules**: ES6 modules with explicit imports/exports
|
||||||
|
- **Functions**: camelCase with JSDoc comments (`getSpeakers`, `addSpeaker`)
|
||||||
|
- **Constants**: UPPER_SNAKE_CASE (`API_BASE_URL`)
|
||||||
|
- **Error Handling**: Comprehensive try/catch with detailed error messages
|
||||||
|
- **Async**: Use async/await consistently, handle response.ok checks
|
||||||
|
- **Testing**: Jest with descriptive test names and comprehensive mocking
|
|
@ -4,6 +4,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from app.routers import speakers, dialog # Import the routers
|
from app.routers import speakers, dialog # Import the routers
|
||||||
from app import config
|
from app import config
|
||||||
|
from app.services.tts_service import get_global_tts_service
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Chatterbox TTS API",
|
title="Chatterbox TTS API",
|
||||||
|
@ -37,4 +38,21 @@ config.DIALOG_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
# Mount StaticFiles to serve generated dialogs
|
# Mount StaticFiles to serve generated dialogs
|
||||||
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
|
app.mount("/generated_audio", StaticFiles(directory=config.DIALOG_GENERATED_DIR), name="generated_audio")
|
||||||
|
|
||||||
|
# Application lifecycle events for TTS model management
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""Load TTS model on application startup."""
|
||||||
|
print("🚀 Starting Chatterbox TTS API...")
|
||||||
|
tts_service = get_global_tts_service()
|
||||||
|
tts_service.load_model()
|
||||||
|
print("✅ TTS model loaded and ready!")
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
"""Unload TTS model on application shutdown."""
|
||||||
|
print("🔄 Shutting down Chatterbox TTS API...")
|
||||||
|
tts_service = get_global_tts_service()
|
||||||
|
tts_service.unload_model()
|
||||||
|
print("✅ TTS model unloaded. Goodbye!")
|
||||||
|
|
||||||
# Further endpoints for speakers, dialog generation, etc., will be added here.
|
# Further endpoints for speakers, dialog generation, etc., will be added here.
|
||||||
|
|
|
@ -4,7 +4,7 @@ import shutil
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.models.dialog_models import DialogRequest, DialogResponse
|
from app.models.dialog_models import DialogRequest, DialogResponse
|
||||||
from app.services.tts_service import TTSService
|
from app.services.tts_service import TTSService, get_global_tts_service
|
||||||
from app.services.speaker_service import SpeakerManagementService
|
from app.services.speaker_service import SpeakerManagementService
|
||||||
from app.services.dialog_processor_service import DialogProcessorService
|
from app.services.dialog_processor_service import DialogProcessorService
|
||||||
from app.services.audio_manipulation_service import AudioManipulationService
|
from app.services.audio_manipulation_service import AudioManipulationService
|
||||||
|
@ -17,8 +17,8 @@ router = APIRouter()
|
||||||
# For now, direct instantiation or simple Depends is fine.
|
# For now, direct instantiation or simple Depends is fine.
|
||||||
|
|
||||||
def get_tts_service():
|
def get_tts_service():
|
||||||
# Consider making device configurable
|
# Return the global singleton instance
|
||||||
return TTSService(device="mps")
|
return get_global_tts_service(device="mps")
|
||||||
|
|
||||||
def get_speaker_management_service():
|
def get_speaker_management_service():
|
||||||
return SpeakerManagementService()
|
return SpeakerManagementService()
|
||||||
|
@ -128,19 +128,7 @@ async def generate_line(
|
||||||
detail=error_detail
|
detail=error_detail
|
||||||
)
|
)
|
||||||
|
|
||||||
async def manage_tts_model_lifecycle(tts_service: TTSService, task_function, *args, **kwargs):
|
# Note: manage_tts_model_lifecycle function removed - model lifecycle now managed at application startup/shutdown
|
||||||
"""Loads TTS model, executes task, then unloads model."""
|
|
||||||
try:
|
|
||||||
print("API: Loading TTS model...")
|
|
||||||
tts_service.load_model()
|
|
||||||
return await task_function(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
# Log or handle specific exceptions if needed before re-raising
|
|
||||||
print(f"API: Error during TTS model lifecycle or task execution: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
print("API: Unloading TTS model...")
|
|
||||||
tts_service.unload_model()
|
|
||||||
|
|
||||||
async def process_dialog_flow(
|
async def process_dialog_flow(
|
||||||
request: DialogRequest,
|
request: DialogRequest,
|
||||||
|
@ -274,10 +262,8 @@ async def generate_dialog_endpoint(
|
||||||
- Concatenates all audio segments into a single file.
|
- Concatenates all audio segments into a single file.
|
||||||
- Creates a ZIP archive of all individual segments and the concatenated file.
|
- Creates a ZIP archive of all individual segments and the concatenated file.
|
||||||
"""
|
"""
|
||||||
# Wrap the core processing logic with model loading/unloading
|
# Model is now loaded at startup and kept loaded - no per-request lifecycle management needed
|
||||||
return await manage_tts_model_lifecycle(
|
return await process_dialog_flow(
|
||||||
tts_service,
|
|
||||||
process_dialog_flow,
|
|
||||||
request=request,
|
request=request,
|
||||||
dialog_processor=dialog_processor,
|
dialog_processor=dialog_processor,
|
||||||
audio_manipulator=audio_manipulator,
|
audio_manipulator=audio_manipulator,
|
||||||
|
|
|
@ -41,10 +41,22 @@ def safe_load_chatterbox_tts(device):
|
||||||
return ChatterboxTTS.from_pretrained(device=device)
|
return ChatterboxTTS.from_pretrained(device=device)
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls, device: str = "mps"):
|
||||||
|
"""Singleton pattern - ensures only one instance exists."""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(TTSService, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda"
|
def __init__(self, device: str = "mps"): # Default to MPS for Macs, can be "cpu" or "cuda"
|
||||||
|
# Only initialize once to prevent resetting the model
|
||||||
|
if not self._initialized:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = None
|
self.model = None
|
||||||
self._ensure_output_dir_exists()
|
self._ensure_output_dir_exists()
|
||||||
|
TTSService._initialized = True
|
||||||
|
|
||||||
def _ensure_output_dir_exists(self):
|
def _ensure_output_dir_exists(self):
|
||||||
"""Ensures the TTS output directory exists."""
|
"""Ensures the TTS output directory exists."""
|
||||||
|
@ -62,12 +74,12 @@ class TTSService:
|
||||||
# Potentially raise an exception or handle appropriately
|
# Potentially raise an exception or handle appropriately
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
print("ChatterboxTTS model already loaded.")
|
print("[Singleton] ChatterboxTTS model already loaded.")
|
||||||
|
|
||||||
def unload_model(self):
|
def unload_model(self):
|
||||||
"""Unloads the model and clears memory."""
|
"""Unloads the model and clears memory."""
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
print("Unloading ChatterboxTTS model and clearing cache...")
|
print("[Singleton] Unloading ChatterboxTTS model and clearing cache...")
|
||||||
del self.model
|
del self.model
|
||||||
self.model = None
|
self.model = None
|
||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
|
@ -76,7 +88,9 @@ class TTSService:
|
||||||
if hasattr(torch.mps, "empty_cache"): # Check if empty_cache is available for MPS
|
if hasattr(torch.mps, "empty_cache"): # Check if empty_cache is available for MPS
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
gc.collect() # Explicitly run garbage collection
|
gc.collect() # Explicitly run garbage collection
|
||||||
print("Model unloaded and memory cleared.")
|
print("[Singleton] Model unloaded and memory cleared.")
|
||||||
|
else:
|
||||||
|
print("[Singleton] Model was not loaded, nothing to unload.")
|
||||||
|
|
||||||
async def generate_speech(
|
async def generate_speech(
|
||||||
self,
|
self,
|
||||||
|
@ -94,10 +108,7 @@ class TTSService:
|
||||||
Saves the output to a .wav file.
|
Saves the output to a .wav file.
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
raise RuntimeError("TTS model is not loaded. Model should be loaded at application startup.")
|
||||||
|
|
||||||
if self.model is None: # Check again if loading failed
|
|
||||||
raise RuntimeError("TTS model is not loaded. Cannot generate speech.")
|
|
||||||
|
|
||||||
# Ensure speaker_sample_path is valid
|
# Ensure speaker_sample_path is valid
|
||||||
speaker_sample_p = Path(speaker_sample_path)
|
speaker_sample_p = Path(speaker_sample_path)
|
||||||
|
@ -130,10 +141,20 @@ class TTSService:
|
||||||
# For now, we keep it loaded. Memory management might need refinement.
|
# For now, we keep it loaded. Memory management might need refinement.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Global singleton instance access
|
||||||
|
_global_tts_service = None
|
||||||
|
|
||||||
|
def get_global_tts_service(device: str = "mps") -> TTSService:
|
||||||
|
"""Get the global singleton TTS service instance."""
|
||||||
|
global _global_tts_service
|
||||||
|
if _global_tts_service is None:
|
||||||
|
_global_tts_service = TTSService(device=device)
|
||||||
|
return _global_tts_service
|
||||||
|
|
||||||
# Example usage (for testing, not part of the service itself)
|
# Example usage (for testing, not part of the service itself)
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
async def main_test():
|
async def main_test():
|
||||||
tts_service = TTSService(device="mps")
|
tts_service = get_global_tts_service(device="mps")
|
||||||
try:
|
try:
|
||||||
tts_service.load_model()
|
tts_service.load_model()
|
||||||
|
|
||||||
|
|
|
@ -28,3 +28,15 @@ dd3552d9-f4e8-49ed-9892-f9e67afcf23c:
|
||||||
2cdd6d3d-c533-44bf-a5f6-cc83bd089d32:
|
2cdd6d3d-c533-44bf-a5f6-cc83bd089d32:
|
||||||
name: Grace
|
name: Grace
|
||||||
sample_path: speaker_samples/2cdd6d3d-c533-44bf-a5f6-cc83bd089d32.wav
|
sample_path: speaker_samples/2cdd6d3d-c533-44bf-a5f6-cc83bd089d32.wav
|
||||||
|
fdbfa71b-7647-4574-a1c0-31350348b434:
|
||||||
|
name: Elthea
|
||||||
|
sample_path: speaker_samples/fdbfa71b-7647-4574-a1c0-31350348b434.wav
|
||||||
|
44cfc6c1-78ec-4278-920a-8ad067cd1eba:
|
||||||
|
name: Eddie
|
||||||
|
sample_path: speaker_samples/44cfc6c1-78ec-4278-920a-8ad067cd1eba.wav
|
||||||
|
a25c52cc-ad56-46d2-9209-62fa7aebb150:
|
||||||
|
name: Charlotte
|
||||||
|
sample_path: speaker_samples/a25c52cc-ad56-46d2-9209-62fa7aebb150.wav
|
||||||
|
aeb43113-586c-4ab8-86e6-3b26737b9816:
|
||||||
|
name: Announcer1
|
||||||
|
sample_path: speaker_samples/aeb43113-586c-4ab8-86e6-3b26737b9816.wav
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
#!/Users/stwhite/CODE/chatterbox-ui/.venv/bin/python
|
#!/Volumes/SAM2/CODE/chatterbox-test/.venv/bin/python
|
||||||
"""
|
"""
|
||||||
Startup script that launches both the backend and frontend servers concurrently.
|
Startup script that launches both the backend and frontend servers concurrently.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue