# chatterbox-ui/backend/app/services/audio_manipulation_service.py

import torch
import torchaudio
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import zipfile
# Default sample rate, matching the TTS model's output. Ideally this would be
# configurable or obtained dynamically; the ChatterboxTTS model's `model.sr`
# is 24000.
DEFAULT_SAMPLE_RATE = 24000


class AudioManipulationService:
    def __init__(self, default_sample_rate: int = DEFAULT_SAMPLE_RATE):
        self.sample_rate = default_sample_rate

    def _load_audio(self, file_path: Union[str, Path]) -> Tuple[torch.Tensor, int]:
        """Loads an audio file and returns the waveform and sample rate."""
        try:
            waveform, sr = torchaudio.load(file_path)
            return waveform, sr
        except Exception as e:
            raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e

    def _create_silence(self, duration_seconds: float, sample_rate: Optional[int] = None) -> torch.Tensor:
        """Creates a mono silent tensor of the given duration at `sample_rate`,
        defaulting to the service's sample rate."""
        sr = sample_rate if sample_rate is not None else self.sample_rate
        num_frames = int(duration_seconds * sr)
        return torch.zeros((1, num_frames))  # Mono silence
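
    # Note: all waveforms in this service use torchaudio's (channels, frames)
    # layout; concatenation happens along dim=1 (the time axis), and stereo
    # inputs are downmixed to mono before concatenation.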

    def concatenate_audio_segments(
        self,
        segment_results: List[Dict],
        output_concatenated_path: Path
    ) -> Path:
        """
        Concatenates audio segments and silences into a single audio file.

        Args:
            segment_results: A list of dictionaries, where each dict represents an audio
                             segment or a silence. Expected format:
                             For speech: {'type': 'speech', 'path': 'path/to/audio.wav', ...}
                             For silence: {'type': 'silence', 'duration': 0.5, ...}
            output_concatenated_path: The path to save the final concatenated audio file.

        Returns:
            The path to the concatenated audio file.
        """
        all_waveforms: List[torch.Tensor] = []
        current_sample_rate = self.sample_rate  # Assumed initially; the first loaded audio may override it.

        for i, segment_info in enumerate(segment_results):
            segment_type = segment_info.get("type")

            if segment_type == "speech":
                audio_path_str = segment_info.get("path")
                if not audio_path_str:
                    print(f"Warning: Speech segment {i} has no path. Skipping.")
                    continue

                audio_path = Path(audio_path_str)
                if not audio_path.exists():
                    print(f"Warning: Audio file {audio_path} for segment {i} not found. Skipping.")
                    continue

                try:
                    waveform, sr = self._load_audio(audio_path)
                    # The first successfully loaded audio sets the reference sample
                    # rate (checking all_waveforms rather than `i == 0` handles
                    # segments skipped above); everything after it is resampled
                    # to match.
                    if not all_waveforms:
                        current_sample_rate = sr
                        if sr != self.sample_rate:
                            print(f"Warning: First audio segment SR ({sr} Hz) differs from service default SR ({self.sample_rate} Hz). Using segment SR.")
                    if sr != current_sample_rate:
                        print(f"Warning: Sample rate mismatch for {audio_path} ({sr} Hz) vs expected ({current_sample_rate} Hz). Resampling...")
                        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=current_sample_rate)
                        waveform = resampler(waveform)

                    # Ensure mono: downmix multi-channel audio by averaging channels.
                    if waveform.shape[0] > 1:
                        waveform = torch.mean(waveform, dim=0, keepdim=True)

                    all_waveforms.append(waveform)
                except Exception as e:
                    print(f"Error processing speech segment {audio_path}: {e}. Skipping.")

            elif segment_type == "silence":
                duration = segment_info.get("duration")
                if duration is None or not isinstance(duration, (int, float)) or duration < 0:
                    print(f"Warning: Silence segment {i} has invalid duration. Skipping.")
                    continue
                # Generate silence at the current reference rate so its duration
                # stays correct even when that rate differs from the default.
                silence_waveform = self._create_silence(float(duration), sample_rate=current_sample_rate)
                all_waveforms.append(silence_waveform)

            elif segment_type == "error":
                # Errors are already logged by DialogProcessorService; just skip here.
                print(f"Skipping segment {i} due to previous error: {segment_info.get('message')}")
                continue

            else:
                print(f"Warning: Unknown segment type '{segment_type}' at index {i}. Skipping.")

        if not all_waveforms:
            raise ValueError("No valid audio segments or silences found to concatenate.")

        # Concatenate all waveforms along the time axis.
        final_waveform = torch.cat(all_waveforms, dim=1)

        # Ensure the output directory exists.
        output_concatenated_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the concatenated audio.
        try:
            torchaudio.save(str(output_concatenated_path), final_waveform, current_sample_rate)
            print(f"Concatenated audio saved to: {output_concatenated_path}")
            return output_concatenated_path
        except Exception as e:
            raise RuntimeError(f"Error saving concatenated audio to {output_concatenated_path}: {e}") from e

    def create_zip_archive(
        self,
        segment_file_paths: List[Path],
        concatenated_audio_path: Path,
        output_zip_path: Path
    ) -> Path:
        """
        Creates a ZIP archive containing individual audio segments and the concatenated audio file.

        Args:
            segment_file_paths: A list of paths to the individual audio segment files.
            concatenated_audio_path: Path to the final concatenated audio file.
            output_zip_path: The path to save the output ZIP archive.

        Returns:
            The path to the created ZIP archive.
        """
        output_zip_path.parent.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Add the concatenated audio.
            if concatenated_audio_path.exists():
                zf.write(concatenated_audio_path, arcname=concatenated_audio_path.name)
            else:
                print(f"Warning: Concatenated audio file {concatenated_audio_path} not found for zipping.")

            # Store individual segments in a subdirectory within the zip for organization.
            segments_dir_name = "segments"
            for file_path in segment_file_paths:
                if file_path.exists() and file_path.is_file():
                    zf.write(file_path, arcname=f"{segments_dir_name}/{file_path.name}")
                else:
                    print(f"Warning: Segment file {file_path} not found or is not a file. Excluding from ZIP.")

        print(f"ZIP archive created at: {output_zip_path}")
        return output_zip_path
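
    # Note: arcnames are derived from basenames, so segment files that share a
    # basename would collide inside the archive (zipfile emits a "Duplicate
    # name" warning and stores both entries); callers should keep segment
    # filenames unique.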


# Example Usage (Test Block)
if __name__ == "__main__":
    import tempfile
    import shutil  # Used by the commented-out cleanup in the finally block.

    # Create a temporary directory for test files.
    test_temp_dir = Path(tempfile.mkdtemp(prefix="audio_manip_test_"))
    print(f"Created temporary test directory: {test_temp_dir}")

    # Instance of the service
    audio_service = AudioManipulationService()

    # --- Test Data Setup ---
    # Create dummy audio files (short clips of silence/noise with different names).
    dummy_sr = audio_service.sample_rate
    segment1_path = test_temp_dir / "segment1_speech.wav"
    segment2_path = test_temp_dir / "segment2_speech.wav"
    torchaudio.save(str(segment1_path), audio_service._create_silence(1.0), dummy_sr)

    # Create a dummy segment with a different sample rate to test resampling.
    dummy_sr_alt = 16000
    temp_waveform_alt_sr = torch.rand((1, int(0.5 * dummy_sr_alt)))  # 0.5 s at 16 kHz
    torchaudio.save(str(segment2_path), temp_waveform_alt_sr, dummy_sr_alt)

    segment_results_for_concat = [
        {"type": "speech", "path": str(segment1_path), "speaker_id": "spk1", "text_chunk": "Test 1"},
        {"type": "silence", "duration": 0.5},
        {"type": "speech", "path": str(segment2_path), "speaker_id": "spk2", "text_chunk": "Test 2 (alt SR)"},
        {"type": "error", "message": "Simulated error, should be skipped"},
        {"type": "speech", "path": "non_existent_segment.wav"},  # Test non-existent file
        {"type": "silence", "duration": -0.2}  # Test invalid duration
    ]

    concatenated_output_path = test_temp_dir / "final_concatenated_audio.wav"
    zip_output_path = test_temp_dir / "audio_archive.zip"
    all_segment_files_for_zip = [segment1_path, segment2_path]

    try:
        # Test concatenation
        print("\n--- Testing Concatenation ---")
        actual_concat_path = audio_service.concatenate_audio_segments(
            segment_results_for_concat,
            concatenated_output_path
        )
        print(f"Concatenation test successful. Output: {actual_concat_path}")
        assert actual_concat_path.exists()

        # Basic check: load the concatenated audio and verify its duration (approximately).
        concat_wav, concat_sr = audio_service._load_audio(actual_concat_path)
        expected_duration = 1.0 + 0.5 + 0.5  # seg1 (1.0s) + silence (0.5s) + seg2 (0.5s) = 2.0s
        actual_duration = concat_wav.shape[1] / concat_sr
        print(f"Expected duration (approx): {expected_duration}s, Actual duration: {actual_duration:.2f}s")
        assert abs(actual_duration - expected_duration) < 0.1  # Allow small deviation

        # Test Zipping
        print("\n--- Testing Zipping ---")
        actual_zip_path = audio_service.create_zip_archive(
            all_segment_files_for_zip,
            actual_concat_path,
            zip_output_path
        )
        print(f"Zipping test successful. Output: {actual_zip_path}")
        assert actual_zip_path.exists()

        # Verify zip contents (basic check). ZIP entry names always use forward slashes.
        segments_dir_name = "segments"  # Matches the subdirectory used by create_zip_archive.
        with zipfile.ZipFile(actual_zip_path, 'r') as zf_read:
            zip_contents = zf_read.namelist()
            print(f"ZIP contents: {zip_contents}")
            assert f"{segments_dir_name}/{segment1_path.name}" in zip_contents
            assert f"{segments_dir_name}/{segment2_path.name}" in zip_contents
            assert concatenated_output_path.name in zip_contents

        print("\nAll AudioManipulationService tests passed!")
    except Exception:
        import traceback
        print("\nAn error occurred during AudioManipulationService tests:")
        traceback.print_exc()
    finally:
        # Clean up the temporary directory:
        # shutil.rmtree(test_temp_dir)
        # print(f"Cleaned up temporary test directory: {test_temp_dir}")
        print(f"Test files are in {test_temp_dir}. Please inspect and delete manually if needed.")