Clean up memory management in cbx-audiobook.py

- Use singleton pattern from TTSService for efficient model management (see the sketch after this list)
- Remove complex manual memory cleanup code
- Simplify CLI arguments by removing redundant memory management options
- Load model once at start, let singleton handle efficient reuse
- Remove keep-model-loaded and cleanup-interval options
- Streamline generation logic to match backend service patterns
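
The singleton approach referenced above can be pictured with the following minimal sketch. This is an assumption-laden illustration, not the actual backend code: only load_model(), unload_model(), and the model attribute are taken from the diff below; the class-level cache and the stub loader are invented for illustration.

# Hypothetical sketch of singleton-style model caching in TTSService.
# Only load_model/unload_model/model mirror calls visible in the diff;
# everything else is an assumption.

def _load_tts_model(device):
    """Stand-in for the real model loader (assumption)."""
    return object()  # placeholder model object

class TTSService:
    _shared_model = None  # class-level cache: at most one model per process

    def __init__(self, device="mps"):
        self.device = device

    @property
    def model(self):
        return TTSService._shared_model

    def load_model(self):
        # First call loads; later calls reuse the cached model.
        if TTSService._shared_model is None:
            print(f"Loading TTS model on {self.device}...")
            TTSService._shared_model = _load_tts_model(self.device)
        return TTSService._shared_model

    def unload_model(self):
        # Drop the shared reference so the model can be garbage-collected.
        TTSService._shared_model = None

Under this shape, repeated load_model() calls are cheap, which is why the per-chunk load/unload and manual cache-clearing code below can be deleted.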

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Steve White 2025-06-27 00:01:13 -05:00
parent a983c31e54
commit 0e522feddf
1 changed file with 9 additions and 101 deletions


@@ -31,7 +31,7 @@ class AudiobookGenerator:
     def __init__(self, speaker_id, output_base_name, device="mps",
                  exaggeration=0.5, cfg_weight=0.5, temperature=0.8,
                  pause_between_sentences=0.5, pause_between_paragraphs=1.0,
-                 keep_model_loaded=False, cleanup_interval=10, use_subprocess=False):
+                 use_subprocess=False):
         """
         Initialize the audiobook generator.
@@ -44,8 +44,6 @@ class AudiobookGenerator:
             temperature: Controls randomness in generation (0.0-1.0)
             pause_between_sentences: Pause duration between sentences in seconds
             pause_between_paragraphs: Pause duration between paragraphs in seconds
-            keep_model_loaded: If True, keeps model loaded across chunks (more efficient but uses more memory)
-            cleanup_interval: How often to perform deep cleanup when keep_model_loaded=True
             use_subprocess: If True, uses separate processes for each chunk (slower but guarantees memory release)
         """
         self.speaker_id = speaker_id
@@ -56,10 +54,7 @@ class AudiobookGenerator:
         self.temperature = temperature
         self.pause_between_sentences = pause_between_sentences
         self.pause_between_paragraphs = pause_between_paragraphs
-        self.keep_model_loaded = keep_model_loaded
-        self.cleanup_interval = cleanup_interval
         self.use_subprocess = use_subprocess
-        self.chunk_counter = 0

         # Initialize services
         self.tts_service = TTSService(device=device)
@@ -86,47 +81,6 @@ class AudiobookGenerator:
         # Store speaker info for later use
         self.speaker_info = speaker_info

-    def _cleanup_memory(self):
-        """Force memory cleanup and garbage collection."""
-        print("Performing memory cleanup...")
-
-        # Force garbage collection multiple times for thorough cleanup
-        for _ in range(3):
-            gc.collect()
-
-        # Clear device-specific caches
-        if self.device == "cuda" and torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-            # Additional CUDA cleanup
-            try:
-                torch.cuda.reset_peak_memory_stats()
-            except:
-                pass
-        elif self.device == "mps" and torch.backends.mps.is_available():
-            if hasattr(torch.mps, "empty_cache"):
-                torch.mps.empty_cache()
-            if hasattr(torch.mps, "synchronize"):
-                torch.mps.synchronize()
-            # Try to free MPS memory more aggressively
-            try:
-                import os
-                # This forces MPS to release memory back to the system
-                if hasattr(torch.mps, "set_per_process_memory_fraction"):
-                    current_allocated = torch.mps.current_allocated_memory() if hasattr(torch.mps, "current_allocated_memory") else 0
-                    if current_allocated > 0:
-                        torch.mps.empty_cache()
-            except:
-                pass
-
-        # Additional aggressive cleanup
-        if hasattr(torch, '_C') and hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
-            try:
-                torch._C._cuda_clearCublasWorkspaces()
-            except:
-                pass
-
-        print("Memory cleanup completed.")
-
     async def _generate_chunk_subprocess(self, chunk, segment_filename_base, speaker_sample_path):
         """
@@ -250,9 +204,8 @@ class AudiobookGenerator:
         segment_results = []
         chunk_count = 0

-        # Pre-load model if keeping it loaded
-        if self.keep_model_loaded:
-            print("Pre-loading TTS model for batch processing...")
-            self.tts_service.load_model()
+        # Load model once at the start (singleton will handle reuse)
+        print("Loading TTS model...")
+        self.tts_service.load_model()

         try:
@@ -261,7 +214,6 @@ class AudiobookGenerator:
                 for chunk_idx, chunk in enumerate(paragraph["chunks"]):
                     chunk_count += 1
-                    self.chunk_counter += 1

                     print(f"  Generating audio for chunk {chunk_count}/{total_chunks}: {chunk[:50]}...")

                     # Generate unique filename for this chunk
@@ -283,12 +235,7 @@ class AudiobookGenerator:
                                 speaker_sample_path=speaker_sample_path
                             )
                         else:
-                            # Load model for this chunk (if not keeping loaded)
-                            if not self.keep_model_loaded:
-                                print("Loading TTS model...")
-                                self.tts_service.load_model()
-
-                            # Generate speech using the TTS service
+                            # Generate speech using the TTS service (model already loaded)
                             segment_output_path = await self.tts_service.generate_speech(
                                 text=chunk,
                                 speaker_id=self.speaker_id,
@@ -300,26 +247,6 @@ class AudiobookGenerator:
                                 temperature=self.temperature
                             )

-                        # Memory management strategy based on model lifecycle
-                        if self.use_subprocess:
-                            # No memory management needed - subprocess handles it
-                            pass
-                        elif self.keep_model_loaded:
-                            # Light cleanup after each chunk
-                            if self.chunk_counter % self.cleanup_interval == 0:
-                                print(f"Performing periodic deep cleanup (chunk {self.chunk_counter})")
-                                self._cleanup_memory()
-                        else:
-                            # Explicit memory cleanup after generation
-                            self._cleanup_memory()
-
-                            # Unload model after generation
-                            print("Unloading TTS model...")
-                            self.tts_service.unload_model()
-
-                            # Additional memory cleanup after model unload
-                            self._cleanup_memory()
-
                         # Add to segment results
                         segment_results.append({
                             "type": "speech",
@@ -335,13 +262,6 @@ class AudiobookGenerator:
                     except Exception as e:
                         print(f"Error generating speech for chunk: {e}")
-                        # Ensure model is unloaded if there was an error and not using subprocess
-                        if not self.use_subprocess:
-                            if not self.keep_model_loaded and self.tts_service.model is not None:
-                                print("Unloading TTS model after error...")
-                                self.tts_service.unload_model()
-                            # Force cleanup after error
-                            self._cleanup_memory()
                         # Continue with next chunk

             # Add longer pause between paragraphs
@@ -352,11 +272,10 @@ class AudiobookGenerator:
                 })

         finally:
-            # Always unload model at the end if it was kept loaded
-            if self.keep_model_loaded and self.tts_service.model is not None:
-                print("Final cleanup: Unloading TTS model...")
+            # Optionally unload model at the end (singleton manages this efficiently)
+            if not self.use_subprocess:
+                print("Unloading TTS model...")
                 self.tts_service.unload_model()
-                self._cleanup_memory()

         # Concatenate all segments
         print("Concatenating audio segments...")
@@ -389,11 +308,6 @@ class AudiobookGenerator:
         print(f"Audiobook file: {concatenated_path}")
         print(f"ZIP archive: {zip_path}")

-        # Ensure model is unloaded at the end (just in case)
-        if self.tts_service.model is not None:
-            print("Final check: Unloading TTS model...")
-            self.tts_service.unload_model()
-
         return concatenated_path

 async def main():
@@ -413,11 +327,9 @@ async def main():
     parser.add_argument("--temperature", type=float, default=0.8, help="Controls randomness (0.0-1.0, default: 0.8)")
     parser.add_argument("--sentence-pause", type=float, default=0.5, help="Pause between sentences in seconds (default: 0.5)")
     parser.add_argument("--paragraph-pause", type=float, default=1.0, help="Pause between paragraphs in seconds (default: 1.0)")
-    parser.add_argument("--keep-model-loaded", action="store_true", help="Keep model loaded between chunks (faster but uses more memory)")
-    parser.add_argument("--cleanup-interval", type=int, default=10, help="How often to perform deep cleanup when keeping model loaded (default: 10)")
     parser.add_argument("--force-cpu-on-oom", action="store_true", help="Automatically switch to CPU if MPS/CUDA runs out of memory")
     parser.add_argument("--max-chunk-length", type=int, default=300, help="Maximum chunk length for text splitting (default: 300)")
-    parser.add_argument("--use-subprocess", action="store_true", help="Use separate processes for each chunk (guarantees memory release but slower)")
+    parser.add_argument("--use-subprocess", action="store_true", help="Use separate processes for each chunk (slower but reduces memory usage)")

     args = parser.parse_args()
@@ -453,8 +365,6 @@ async def main():
             temperature=args.temperature,
             pause_between_sentences=args.sentence_pause,
             pause_between_paragraphs=args.paragraph_pause,
-            keep_model_loaded=args.keep_model_loaded,
-            cleanup_interval=args.cleanup_interval,
             use_subprocess=args.use_subprocess
         )
@@ -476,8 +386,6 @@ async def main():
             temperature=args.temperature,
             pause_between_sentences=args.sentence_pause,
             pause_between_paragraphs=args.paragraph_pause,
-            keep_model_loaded=args.keep_model_loaded,
-            cleanup_interval=args.cleanup_interval,
             use_subprocess=args.use_subprocess
         )
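
For reference, a hypothetical construction of the simplified generator after this change. The parameter list matches the updated __init__ above; the argument values are illustrative only:

generator = AudiobookGenerator(
    speaker_id="speaker-01",       # illustrative value
    output_base_name="my_book",    # illustrative value
    device="mps",
    exaggeration=0.5,
    cfg_weight=0.5,
    temperature=0.8,
    pause_between_sentences=0.5,
    pause_between_paragraphs=1.0,
    use_subprocess=False           # keep_model_loaded / cleanup_interval no longer exist
)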