import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torchaudio

# Default output sample rate. The ChatterboxTTS model outputs at 24 kHz
# (model.sr == 24000). Ideally this would be configurable or obtained
# dynamically from the TTS model at runtime.
DEFAULT_SAMPLE_RATE = 24000


class AudioManipulationService:
    """Assembles per-segment TTS audio (speech + silences) into a single
    file and packages the artifacts into a ZIP archive."""

    def __init__(self, default_sample_rate: int = DEFAULT_SAMPLE_RATE):
        # Fallback sample rate, used unless a loaded segment establishes one.
        self.sample_rate = default_sample_rate

    def _load_audio(self, file_path: Union[str, Path]) -> Tuple[torch.Tensor, int]:
        """Load an audio file and return ``(waveform, sample_rate)``.

        Raises:
            RuntimeError: if torchaudio fails to read the file.
        """
        try:
            waveform, sr = torchaudio.load(file_path)
            return waveform, sr
        except Exception as e:
            raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e

    def _create_silence(
        self, duration_seconds: float, sample_rate: Optional[int] = None
    ) -> torch.Tensor:
        """Create a mono silent tensor of the given duration.

        Args:
            duration_seconds: Length of the silence in seconds.
            sample_rate: Rate used to compute the frame count. Defaults to
                the service's sample rate. Passing the *actual* output rate
                matters: a frame count computed at the wrong rate changes
                the silence's real-time duration.
        """
        sr = self.sample_rate if sample_rate is None else sample_rate
        num_frames = int(duration_seconds * sr)
        return torch.zeros((1, num_frames))  # shape (1, frames): mono

    def concatenate_audio_segments(
        self, segment_results: List[Dict], output_concatenated_path: Path
    ) -> Path:
        """
        Concatenate audio segments and silences into a single audio file.

        Args:
            segment_results: A list of dictionaries, where each dict represents
                an audio segment or a silence. Expected format:
                For speech:  {'type': 'speech', 'path': 'path/to/audio.wav', ...}
                For silence: {'type': 'silence', 'duration': 0.5, ...}
                For errors:  {'type': 'error', 'message': ...} (skipped here;
                already logged upstream by DialogProcessorService).
            output_concatenated_path: The path to save the final concatenated
                audio file.

        Returns:
            The path to the concatenated audio file.

        Raises:
            ValueError: if no valid segments or silences remain after filtering.
            RuntimeError: if the concatenated file cannot be saved.
        """
        # Phase 1: validate and load segments, but defer silence synthesis.
        # The reference sample rate is only known once the first speech
        # segment has been loaded; generating silences before that point
        # (at the service default rate) would skew their real-time duration
        # whenever the speech audio uses a different rate.
        loaded: List[Tuple[str, object]] = []

        for i, segment_info in enumerate(segment_results):
            segment_type = segment_info.get("type")

            if segment_type == "speech":
                audio_path_str = segment_info.get("path")
                if not audio_path_str:
                    print(f"Warning: Speech segment {i} has no path. Skipping.")
                    continue
                audio_path = Path(audio_path_str)
                if not audio_path.exists():
                    print(f"Warning: Audio file {audio_path} for segment {i} not found. Skipping.")
                    continue
                try:
                    waveform, sr = self._load_audio(audio_path)
                except Exception as e:
                    print(f"Error processing speech segment {audio_path}: {e}. Skipping.")
                    continue
                # Downmix to mono so all pieces concatenate on the time axis.
                if waveform.shape[0] > 1:
                    waveform = torch.mean(waveform, dim=0, keepdim=True)
                loaded.append(("speech", (waveform, sr)))

            elif segment_type == "silence":
                duration = segment_info.get("duration")
                if duration is None or not isinstance(duration, (int, float)) or duration < 0:
                    print(f"Warning: Silence segment {i} has invalid duration. Skipping.")
                    continue
                loaded.append(("silence", float(duration)))

            elif segment_type == "error":
                # Errors are already logged by DialogProcessorService, just skip here.
                print(f"Skipping segment {i} due to previous error: {segment_info.get('message')}")

            else:
                print(f"Warning: Unknown segment type '{segment_type}' at index {i}. Skipping.")

        if not loaded:
            raise ValueError("No valid audio segments or silences found to concatenate.")

        # Reference sample rate: the first successfully loaded speech segment
        # wins; otherwise fall back to the service default (silence-only input).
        current_sample_rate = self.sample_rate
        for kind, payload in loaded:
            if kind == "speech":
                current_sample_rate = payload[1]
                if current_sample_rate != self.sample_rate:
                    print(
                        f"Warning: First audio segment SR ({current_sample_rate} Hz) differs "
                        f"from service default SR ({self.sample_rate} Hz). Using segment SR."
                    )
                break

        # Phase 2: bring everything to the reference rate, then concatenate.
        all_waveforms: List[torch.Tensor] = []
        for kind, payload in loaded:
            if kind == "speech":
                waveform, sr = payload
                if sr != current_sample_rate:
                    print(
                        f"Warning: Sample rate mismatch ({sr} Hz) vs expected "
                        f"({current_sample_rate} Hz). Resampling..."
                    )
                    resampler = torchaudio.transforms.Resample(
                        orig_freq=sr, new_freq=current_sample_rate
                    )
                    waveform = resampler(waveform)
                all_waveforms.append(waveform)
            else:  # silence — synthesized at the reference rate (see Phase 1 note)
                all_waveforms.append(self._create_silence(payload, current_sample_rate))

        final_waveform = torch.cat(all_waveforms, dim=1)

        # Ensure output directory exists before saving.
        output_concatenated_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            torchaudio.save(str(output_concatenated_path), final_waveform, current_sample_rate)
            print(f"Concatenated audio saved to: {output_concatenated_path}")
            return output_concatenated_path
        except Exception as e:
            raise RuntimeError(
                f"Error saving concatenated audio to {output_concatenated_path}: {e}"
            ) from e

    def create_zip_archive(
        self,
        segment_file_paths: List[Path],
        concatenated_audio_path: Path,
        output_zip_path: Path,
    ) -> Path:
        """
        Create a ZIP archive containing individual audio segments and the
        concatenated audio file.

        Args:
            segment_file_paths: Paths to the individual audio segment files.
            concatenated_audio_path: Path to the final concatenated audio file.
            output_zip_path: The path to save the output ZIP archive.

        Returns:
            The path to the created ZIP archive.
        """
        output_zip_path.parent.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            # Add concatenated audio at the archive root.
            if concatenated_audio_path.exists():
                zf.write(concatenated_audio_path, arcname=concatenated_audio_path.name)
            else:
                print(f"Warning: Concatenated audio file {concatenated_audio_path} not found for zipping.")

            # Store segments in a subdirectory within the zip for organization.
            segments_dir_name = "segments"
            for file_path in segment_file_paths:
                if file_path.exists() and file_path.is_file():
                    zf.write(file_path, arcname=Path(segments_dir_name) / file_path.name)
                else:
                    print(f"Warning: Segment file {file_path} not found or is not a file. Skipping for zipping.")

        print(f"ZIP archive created at: {output_zip_path}")
        return output_zip_path


# Example Usage (Test Block)
if __name__ == "__main__":
    import tempfile
    import shutil  # retained for the (commented-out) cleanup step below

    # Create a temporary directory for test files.
    test_temp_dir = Path(tempfile.mkdtemp(prefix="audio_manip_test_"))
    print(f"Created temporary test directory: {test_temp_dir}")

    audio_service = AudioManipulationService()

    # --- Test Data Setup ---
    # Create dummy audio files (short silences / noise with different rates).
    dummy_sr = audio_service.sample_rate
    segment1_path = test_temp_dir / "segment1_speech.wav"
    segment2_path = test_temp_dir / "segment2_speech.wav"
    torchaudio.save(str(segment1_path), audio_service._create_silence(1.0), dummy_sr)

    # A dummy segment with a different sample rate to exercise resampling.
    dummy_sr_alt = 16000
    temp_waveform_alt_sr = torch.rand((1, int(0.5 * dummy_sr_alt)))  # 0.5s at 16kHz
    torchaudio.save(str(segment2_path), temp_waveform_alt_sr, dummy_sr_alt)

    segment_results_for_concat = [
        {"type": "speech", "path": str(segment1_path), "speaker_id": "spk1", "text_chunk": "Test 1"},
        {"type": "silence", "duration": 0.5},
        {"type": "speech", "path": str(segment2_path), "speaker_id": "spk2", "text_chunk": "Test 2 (alt SR)"},
        {"type": "error", "message": "Simulated error, should be skipped"},
        {"type": "speech", "path": "non_existent_segment.wav"},  # Test non-existent file
        {"type": "silence", "duration": -0.2},  # Test invalid duration
    ]

    concatenated_output_path = test_temp_dir / "final_concatenated_audio.wav"
    zip_output_path = test_temp_dir / "audio_archive.zip"
    all_segment_files_for_zip = [segment1_path, segment2_path]

    try:
        # Test concatenation
        print("\n--- Testing Concatenation ---")
        actual_concat_path = audio_service.concatenate_audio_segments(
            segment_results_for_concat, concatenated_output_path
        )
        print(f"Concatenation test successful. Output: {actual_concat_path}")
        assert actual_concat_path.exists()

        # Basic check: load concatenated audio and verify approximate duration.
        concat_wav, concat_sr = audio_service._load_audio(actual_concat_path)
        expected_duration = 1.0 + 0.5 + 0.5  # seg1 (1.0s) + silence (0.5s) + seg2 (0.5s) = 2.0s
        actual_duration = concat_wav.shape[1] / concat_sr
        print(f"Expected duration (approx): {expected_duration}s, Actual duration: {actual_duration:.2f}s")
        assert abs(actual_duration - expected_duration) < 0.1  # Allow small deviation

        # Test Zipping
        print("\n--- Testing Zipping ---")
        actual_zip_path = audio_service.create_zip_archive(
            all_segment_files_for_zip, actual_concat_path, zip_output_path
        )
        print(f"Zipping test successful. Output: {actual_zip_path}")
        assert actual_zip_path.exists()

        # Verify zip contents (basic check).
        segments_dir_name = "segments"  # Must match the name used by create_zip_archive
        with zipfile.ZipFile(actual_zip_path, 'r') as zf_read:
            zip_contents = zf_read.namelist()
            print(f"ZIP contents: {zip_contents}")
            assert Path(segments_dir_name) / segment1_path.name in [Path(p) for p in zip_contents]
            assert Path(segments_dir_name) / segment2_path.name in [Path(p) for p in zip_contents]
            assert concatenated_output_path.name in zip_contents

        print("\nAll AudioManipulationService tests passed!")

    except Exception as e:
        import traceback
        print(f"\nAn error occurred during AudioManipulationService tests:")
        traceback.print_exc()
    finally:
        # Clean up temporary directory
        # shutil.rmtree(test_temp_dir)
        # print(f"Cleaned up temporary test directory: {test_temp_dir}")
        print(f"Test files are in {test_temp_dir}. Please inspect and delete manually if needed.")