Clean up memory management in cbx-audiobook.py
- Use singleton pattern from TTSService for efficient model management
- Remove complex manual memory cleanup code
- Simplify CLI arguments by removing redundant memory management options
- Load model once at start, let singleton handle efficient reuse
- Remove keep-model-loaded and cleanup-interval options
- Streamline generation logic to match backend service patterns

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
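The key idea is that the backend's TTSService keeps one shared model instance per process, so the script can call load_model() freely and let the service decide whether anything actually needs loading. Below is a minimal sketch of that pattern, assuming class-level state; the method names (load_model, unload_model, generate_speech) and the device= argument come from the diff, while the internals are illustrative assumptions, not the actual backend implementation.

```python
# Minimal sketch of the singleton-style model handling this commit relies on.
# Method names and the device= argument appear in the diff below; the internals
# here are illustrative assumptions, not the real backend code.

class TTSService:
    _shared_model = None  # one model instance shared across all TTSService objects

    def __init__(self, device="mps"):
        self.device = device

    @property
    def model(self):
        return TTSService._shared_model

    def load_model(self):
        # Safe to call before every generation: only loads when nothing is cached.
        if TTSService._shared_model is None:
            TTSService._shared_model = self._load_backend_model()
        return TTSService._shared_model

    def unload_model(self):
        # Drop the shared reference; device cache handling is left to the backend.
        TTSService._shared_model = None

    def _load_backend_model(self):
        # Placeholder for the real model loader, which is not part of this diff.
        return object()
```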
commit 0e522feddf
parent a983c31e54
cbx-audiobook.py | 110

@@ -31,7 +31,7 @@ class AudiobookGenerator:
     def __init__(self, speaker_id, output_base_name, device="mps",
                  exaggeration=0.5, cfg_weight=0.5, temperature=0.8,
                  pause_between_sentences=0.5, pause_between_paragraphs=1.0,
-                 keep_model_loaded=False, cleanup_interval=10, use_subprocess=False):
+                 use_subprocess=False):
         """
         Initialize the audiobook generator.

@@ -44,8 +44,6 @@ class AudiobookGenerator:
             temperature: Controls randomness in generation (0.0-1.0)
             pause_between_sentences: Pause duration between sentences in seconds
             pause_between_paragraphs: Pause duration between paragraphs in seconds
-            keep_model_loaded: If True, keeps model loaded across chunks (more efficient but uses more memory)
-            cleanup_interval: How often to perform deep cleanup when keep_model_loaded=True
             use_subprocess: If True, uses separate processes for each chunk (slower but guarantees memory release)
         """
         self.speaker_id = speaker_id
@@ -56,10 +54,7 @@ class AudiobookGenerator:
         self.temperature = temperature
         self.pause_between_sentences = pause_between_sentences
         self.pause_between_paragraphs = pause_between_paragraphs
-        self.keep_model_loaded = keep_model_loaded
-        self.cleanup_interval = cleanup_interval
         self.use_subprocess = use_subprocess
-        self.chunk_counter = 0

         # Initialize services
         self.tts_service = TTSService(device=device)
@@ -86,47 +81,6 @@ class AudiobookGenerator:
         # Store speaker info for later use
         self.speaker_info = speaker_info

-    def _cleanup_memory(self):
-        """Force memory cleanup and garbage collection."""
-        print("Performing memory cleanup...")
-
-        # Force garbage collection multiple times for thorough cleanup
-        for _ in range(3):
-            gc.collect()
-
-        # Clear device-specific caches
-        if self.device == "cuda" and torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-            # Additional CUDA cleanup
-            try:
-                torch.cuda.reset_peak_memory_stats()
-            except:
-                pass
-        elif self.device == "mps" and torch.backends.mps.is_available():
-            if hasattr(torch.mps, "empty_cache"):
-                torch.mps.empty_cache()
-            if hasattr(torch.mps, "synchronize"):
-                torch.mps.synchronize()
-            # Try to free MPS memory more aggressively
-            try:
-                import os
-                # This forces MPS to release memory back to the system
-                if hasattr(torch.mps, "set_per_process_memory_fraction"):
-                    current_allocated = torch.mps.current_allocated_memory() if hasattr(torch.mps, "current_allocated_memory") else 0
-                    if current_allocated > 0:
-                        torch.mps.empty_cache()
-            except:
-                pass
-
-        # Additional aggressive cleanup
-        if hasattr(torch, '_C') and hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
-            try:
-                torch._C._cuda_clearCublasWorkspaces()
-            except:
-                pass
-
-        print("Memory cleanup completed.")
-
     async def _generate_chunk_subprocess(self, chunk, segment_filename_base, speaker_sample_path):
         """
@@ -250,10 +204,9 @@ class AudiobookGenerator:
         segment_results = []
         chunk_count = 0

-        # Pre-load model if keeping it loaded
-        if self.keep_model_loaded:
-            print("Pre-loading TTS model for batch processing...")
-            self.tts_service.load_model()
+        # Load model once at the start (singleton will handle reuse)
+        print("Loading TTS model...")
+        self.tts_service.load_model()

         try:
             for para_idx, paragraph in enumerate(paragraphs):
@@ -261,7 +214,6 @@ class AudiobookGenerator:

                 for chunk_idx, chunk in enumerate(paragraph["chunks"]):
                     chunk_count += 1
-                    self.chunk_counter += 1
                     print(f" Generating audio for chunk {chunk_count}/{total_chunks}: {chunk[:50]}...")

                     # Generate unique filename for this chunk
@@ -283,12 +235,7 @@ class AudiobookGenerator:
                             speaker_sample_path=speaker_sample_path
                         )
                     else:
-                        # Load model for this chunk (if not keeping loaded)
-                        if not self.keep_model_loaded:
-                            print("Loading TTS model...")
-                            self.tts_service.load_model()
-
-                        # Generate speech using the TTS service
+                        # Generate speech using the TTS service (model already loaded)
                         segment_output_path = await self.tts_service.generate_speech(
                             text=chunk,
                             speaker_id=self.speaker_id,
@@ -300,26 +247,6 @@ class AudiobookGenerator:
                             temperature=self.temperature
                         )

-                        # Memory management strategy based on model lifecycle
-                        if self.use_subprocess:
-                            # No memory management needed - subprocess handles it
-                            pass
-                        elif self.keep_model_loaded:
-                            # Light cleanup after each chunk
-                            if self.chunk_counter % self.cleanup_interval == 0:
-                                print(f"Performing periodic deep cleanup (chunk {self.chunk_counter})")
-                                self._cleanup_memory()
-                        else:
-                            # Explicit memory cleanup after generation
-                            self._cleanup_memory()
-
-                            # Unload model after generation
-                            print("Unloading TTS model...")
-                            self.tts_service.unload_model()
-
-                            # Additional memory cleanup after model unload
-                            self._cleanup_memory()
-
                     # Add to segment results
                     segment_results.append({
                         "type": "speech",
@@ -335,13 +262,6 @@ class AudiobookGenerator:

                 except Exception as e:
                     print(f"Error generating speech for chunk: {e}")
-                    # Ensure model is unloaded if there was an error and not using subprocess
-                    if not self.use_subprocess:
-                        if not self.keep_model_loaded and self.tts_service.model is not None:
-                            print("Unloading TTS model after error...")
-                            self.tts_service.unload_model()
-                        # Force cleanup after error
-                        self._cleanup_memory()
                     # Continue with next chunk

             # Add longer pause between paragraphs
@@ -352,11 +272,10 @@ class AudiobookGenerator:
                 })

         finally:
-            # Always unload model at the end if it was kept loaded
-            if self.keep_model_loaded and self.tts_service.model is not None:
-                print("Final cleanup: Unloading TTS model...")
+            # Optionally unload model at the end (singleton manages this efficiently)
+            if not self.use_subprocess:
+                print("Unloading TTS model...")
                 self.tts_service.unload_model()
-                self._cleanup_memory()

         # Concatenate all segments
         print("Concatenating audio segments...")
@@ -389,11 +308,6 @@ class AudiobookGenerator:
         print(f"Audiobook file: {concatenated_path}")
         print(f"ZIP archive: {zip_path}")

-        # Ensure model is unloaded at the end (just in case)
-        if self.tts_service.model is not None:
-            print("Final check: Unloading TTS model...")
-            self.tts_service.unload_model()
-
         return concatenated_path

 async def main():
@@ -413,11 +327,9 @@ async def main():
     parser.add_argument("--temperature", type=float, default=0.8, help="Controls randomness (0.0-1.0, default: 0.8)")
     parser.add_argument("--sentence-pause", type=float, default=0.5, help="Pause between sentences in seconds (default: 0.5)")
     parser.add_argument("--paragraph-pause", type=float, default=1.0, help="Pause between paragraphs in seconds (default: 1.0)")
-    parser.add_argument("--keep-model-loaded", action="store_true", help="Keep model loaded between chunks (faster but uses more memory)")
-    parser.add_argument("--cleanup-interval", type=int, default=10, help="How often to perform deep cleanup when keeping model loaded (default: 10)")
     parser.add_argument("--force-cpu-on-oom", action="store_true", help="Automatically switch to CPU if MPS/CUDA runs out of memory")
     parser.add_argument("--max-chunk-length", type=int, default=300, help="Maximum chunk length for text splitting (default: 300)")
-    parser.add_argument("--use-subprocess", action="store_true", help="Use separate processes for each chunk (guarantees memory release but slower)")
+    parser.add_argument("--use-subprocess", action="store_true", help="Use separate processes for each chunk (slower but reduces memory usage)")

     args = parser.parse_args()

@@ -453,8 +365,6 @@ async def main():
         temperature=args.temperature,
         pause_between_sentences=args.sentence_pause,
         pause_between_paragraphs=args.paragraph_pause,
-        keep_model_loaded=args.keep_model_loaded,
-        cleanup_interval=args.cleanup_interval,
         use_subprocess=args.use_subprocess
     )

@@ -476,8 +386,6 @@ async def main():
         temperature=args.temperature,
         pause_between_sentences=args.sentence_pause,
         pause_between_paragraphs=args.paragraph_pause,
-        keep_model_loaded=args.keep_model_loaded,
-        cleanup_interval=args.cleanup_interval,
         use_subprocess=args.use_subprocess
     )
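With the memory-management knobs gone, callers construct AudiobookGenerator using only the generation, pacing, and subprocess options that remain in the simplified signature. A usage sketch based on the new __init__ shown in the first hunk; the speaker id and output name below are placeholder values, not real project data:

```python
# Hypothetical usage of the simplified constructor; "narrator" and "my_book"
# are placeholders, and the keyword arguments mirror the new signature.
generator = AudiobookGenerator(
    speaker_id="narrator",
    output_base_name="my_book",
    device="mps",
    exaggeration=0.5,
    cfg_weight=0.5,
    temperature=0.8,
    pause_between_sentences=0.5,
    pause_between_paragraphs=1.0,
    use_subprocess=False,  # set True to trade speed for guaranteed per-chunk memory release
)
```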