import torch
import torchaudio
from pathlib import Path
from typing import List, Dict, Union, Tuple
import zipfile

# Define a common sample rate, e.g., from the TTS model. This should ideally be
# configurable or dynamically obtained. For now, assume the TTS model
# (ChatterboxTTS) outputs at a known sample rate: ChatterboxTTS reports
# model.sr == 24000.
DEFAULT_SAMPLE_RATE = 24000
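
# A sketch of obtaining the rate dynamically instead of hard-coding it. Only the
# `model.sr` attribute is confirmed above; the loader call is hypothetical and
# should be adjusted to the actual ChatterboxTTS API:
#
#     model = ChatterboxTTS.from_pretrained(device="cpu")  # hypothetical loader
#     DEFAULT_SAMPLE_RATE = model.sr                       # 24000 for ChatterboxTTS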


class AudioManipulationService:
    def __init__(self, default_sample_rate: int = DEFAULT_SAMPLE_RATE):
        self.sample_rate = default_sample_rate

    def _load_audio(self, file_path: Union[str, Path]) -> Tuple[torch.Tensor, int]:
        """Loads an audio file and returns the waveform and sample rate."""
        try:
            waveform, sr = torchaudio.load(file_path)
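            # torchaudio.load returns (waveform, sample_rate), with the waveform
            # shaped [channels, frames].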
            return waveform, sr
        except Exception as e:
            raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e

    def _create_silence(self, duration_seconds: float) -> torch.Tensor:
        """Creates a silent audio tensor of a given duration."""
        num_frames = int(duration_seconds * self.sample_rate)
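        # e.g., 0.5 s at the default 24000 Hz gives int(0.5 * 24000) = 12000 frames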
        return torch.zeros((1, num_frames))  # Mono silence

    def concatenate_audio_segments(
        self,
        segment_results: List[Dict],
        output_concatenated_path: Path
    ) -> Path:
        """
        Concatenates audio segments and silences into a single audio file.

        Args:
            segment_results: A list of dictionaries, where each dict represents an audio
                             segment or a silence. Expected format:
                             For speech: {'type': 'speech', 'path': 'path/to/audio.wav', ...}
                             For silence: {'type': 'silence', 'duration': 0.5, ...}
            output_concatenated_path: The path to save the final concatenated audio file.

        Returns:
            The path to the concatenated audio file.
        """
        all_waveforms: List[torch.Tensor] = []
        current_sample_rate = self.sample_rate  # Reference rate; reset by the first loaded waveform

        for i, segment_info in enumerate(segment_results):
            segment_type = segment_info.get("type")

            if segment_type == "speech":
                audio_path_str = segment_info.get("path")
                if not audio_path_str:
                    print(f"Warning: Speech segment {i} has no path. Skipping.")
                    continue

                audio_path = Path(audio_path_str)
                if not audio_path.exists():
                    print(f"Warning: Audio file {audio_path} for segment {i} not found. Skipping.")
                    continue

                try:
                    waveform, sr = self._load_audio(audio_path)
                    # Ensure a consistent sample rate: the first loaded audio sets
                    # the reference rate, and later segments whose rate differs are
                    # resampled to it below.
                    if not all_waveforms:  # First appended waveform sets the reference SR
                        # (checking `i == 0` instead would miss a first speech
                        # segment that follows skipped entries)
                        current_sample_rate = sr
                        if sr != self.sample_rate:
                            print(f"Warning: First audio segment SR ({sr} Hz) differs from service default SR ({self.sample_rate} Hz). Using segment SR.")

                    if sr != current_sample_rate:
                        print(f"Warning: Sample rate mismatch for {audio_path} ({sr} Hz) vs expected ({current_sample_rate} Hz). Resampling...")
                        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=current_sample_rate)
                        waveform = resampler(waveform)
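                    # The resampling above builds a fresh Resample transform per
                    # segment, which is simple but repeats kernel setup. As an
                    # alternative sketch, torchaudio.functional.resample(waveform,
                    # sr, current_sample_rate) does the same conversion in one call;
                    # caching one transform per (orig_freq, new_freq) pair would
                    # also work when many segments share a rate.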

                    # Ensure mono. If stereo, take the mean or first channel.
                    if waveform.shape[0] > 1:
                        waveform = torch.mean(waveform, dim=0, keepdim=True)

                    all_waveforms.append(waveform)
                except Exception as e:
                    print(f"Error processing speech segment {audio_path}: {e}. Skipping.")

            elif segment_type == "silence":
                duration = segment_info.get("duration")
                if duration is None or not isinstance(duration, (int, float)) or duration < 0:
                    print(f"Warning: Silence segment {i} has invalid duration. Skipping.")
                    continue
                silence_waveform = self._create_silence(float(duration))
                all_waveforms.append(silence_waveform)

            elif segment_type == "error":
                # Errors are already logged by DialogProcessorService; just skip here.
                print(f"Skipping segment {i} due to previous error: {segment_info.get('message')}")
                continue

            else:
                print(f"Warning: Unknown segment type '{segment_type}' at index {i}. Skipping.")

        if not all_waveforms:
            raise ValueError("No valid audio segments or silences found to concatenate.")

        # Concatenate along dim=1 (the time axis); every tensor here is mono
        # with shape [1, frames].
        final_waveform = torch.cat(all_waveforms, dim=1)

        # Ensure output directory exists
        output_concatenated_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the concatenated audio
        try:
            torchaudio.save(str(output_concatenated_path), final_waveform, current_sample_rate)
            print(f"Concatenated audio saved to: {output_concatenated_path}")
            return output_concatenated_path
        except Exception as e:
            raise RuntimeError(f"Error saving concatenated audio to {output_concatenated_path}: {e}") from e

    def create_zip_archive(
        self,
        segment_file_paths: List[Path],
        concatenated_audio_path: Path,
        output_zip_path: Path
    ) -> Path:
        """
        Creates a ZIP archive containing individual audio segments and the concatenated audio file.

        Args:
            segment_file_paths: A list of paths to the individual audio segment files.
            concatenated_audio_path: Path to the final concatenated audio file.
            output_zip_path: The path to save the output ZIP archive.

        Returns:
            The path to the created ZIP archive.
        """
        output_zip_path.parent.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Add concatenated audio
            if concatenated_audio_path.exists():
                zf.write(concatenated_audio_path, arcname=concatenated_audio_path.name)
            else:
                print(f"Warning: Concatenated audio file {concatenated_audio_path} not found for zipping.")

            # Add individual segments, stored in a subdirectory within the zip for organization
            segments_dir_name = "segments"
            for file_path in segment_file_paths:
                if file_path.exists() and file_path.is_file():
                    zf.write(file_path, arcname=Path(segments_dir_name) / file_path.name)
                else:
                    print(f"Warning: Segment file {file_path} not found or is not a file. Skipping.")

        print(f"ZIP archive created at: {output_zip_path}")
        return output_zip_path
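
    # Resulting archive layout, for illustration (file names as produced by the
    # test block below):
    #   final_concatenated_audio.wav
    #   segments/segment1_speech.wav
    #   segments/segment2_speech.wav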


# Example Usage (Test Block)
if __name__ == "__main__":
    import tempfile
    import shutil  # used only by the commented-out cleanup in the finally block

    # Create a temporary directory for test files
    test_temp_dir = Path(tempfile.mkdtemp(prefix="audio_manip_test_"))
    print(f"Created temporary test directory: {test_temp_dir}")

    # Instance of the service
    audio_service = AudioManipulationService()

    # --- Test Data Setup ---
    # Create dummy audio files (e.g., short silences with different names)
    dummy_sr = audio_service.sample_rate
    segment1_path = test_temp_dir / "segment1_speech.wav"
    segment2_path = test_temp_dir / "segment2_speech.wav"

    torchaudio.save(str(segment1_path), audio_service._create_silence(1.0), dummy_sr)
    # Create a dummy segment with a different sample rate to test resampling
    dummy_sr_alt = 16000
    temp_waveform_alt_sr = torch.rand((1, int(0.5 * dummy_sr_alt)))  # 0.5 s at 16 kHz
    torchaudio.save(str(segment2_path), temp_waveform_alt_sr, dummy_sr_alt)

    segment_results_for_concat = [
        {"type": "speech", "path": str(segment1_path), "speaker_id": "spk1", "text_chunk": "Test 1"},
        {"type": "silence", "duration": 0.5},
        {"type": "speech", "path": str(segment2_path), "speaker_id": "spk2", "text_chunk": "Test 2 (alt SR)"},
        {"type": "error", "message": "Simulated error, should be skipped"},
        {"type": "speech", "path": "non_existent_segment.wav"},  # Test non-existent file
        {"type": "silence", "duration": -0.2}  # Test invalid duration
    ]

    concatenated_output_path = test_temp_dir / "final_concatenated_audio.wav"
    zip_output_path = test_temp_dir / "audio_archive.zip"

    all_segment_files_for_zip = [segment1_path, segment2_path]

    try:
        # Test concatenation
        print("\n--- Testing Concatenation ---")
        actual_concat_path = audio_service.concatenate_audio_segments(
            segment_results_for_concat,
            concatenated_output_path
        )
        print(f"Concatenation test successful. Output: {actual_concat_path}")
        assert actual_concat_path.exists()
        # Basic check: load concatenated and verify duration (approx)
        concat_wav, concat_sr = audio_service._load_audio(actual_concat_path)
        expected_duration = 1.0 + 0.5 + 0.5  # seg1 (1.0s) + silence (0.5s) + seg2 (0.5s) = 2.0s
        actual_duration = concat_wav.shape[1] / concat_sr
        print(f"Expected duration (approx): {expected_duration}s, Actual duration: {actual_duration:.2f}s")
        assert abs(actual_duration - expected_duration) < 0.1  # Allow small deviation

        # Test Zipping
        print("\n--- Testing Zipping ---")
        actual_zip_path = audio_service.create_zip_archive(
            all_segment_files_for_zip,
            actual_concat_path,
            zip_output_path
        )
        print(f"Zipping test successful. Output: {actual_zip_path}")
        assert actual_zip_path.exists()
        # Verify zip contents (basic check)
        segments_dir_name = "segments"  # Matches the directory name used by create_zip_archive
        with zipfile.ZipFile(actual_zip_path, 'r') as zf_read:
            zip_contents = zf_read.namelist()
            print(f"ZIP contents: {zip_contents}")
            assert Path(segments_dir_name) / segment1_path.name in [Path(p) for p in zip_contents]
            assert Path(segments_dir_name) / segment2_path.name in [Path(p) for p in zip_contents]
            assert concatenated_output_path.name in zip_contents

        print("\nAll AudioManipulationService tests passed!")

    except Exception:
        import traceback
        print("\nAn error occurred during AudioManipulationService tests:")
        traceback.print_exc()
    finally:
        # Clean up temporary directory
        # shutil.rmtree(test_temp_dir)
        # print(f"Cleaned up temporary test directory: {test_temp_dir}")
        print(f"Test files are in {test_temp_dir}. Please inspect and delete manually if needed.")