|
|
|
|
|
import math

import torch
|
|
def slice_windows(audio_batch: torch.Tensor,
                  sample_rate: int = 16000,
                  window_size_ms: int = 160,
                  stride_ms: int = 80) -> torch.Tensor:
    """
    Create fixed-size windows with overlap from a batch of audio sequences using vectorized operations.

    Args:
        audio_batch: Input audio of shape [batch_size, 1, max_audio_length]
        sample_rate: Audio sample rate in Hz
        window_size_ms: Window size in milliseconds
        stride_ms: Stride size in milliseconds

    Returns:
        Tensor of shape [batch_size, num_windows, window_size]
    """
    audio_batch = audio_batch.squeeze(1)
    batch_size, max_audio_length = audio_batch.shape

    # Convert window/stride from milliseconds to samples.
    window_size = int(window_size_ms * sample_rate / 1000)
    stride = int(stride_ms * sample_rate / 1000)
    num_windows = ((max_audio_length - window_size) // stride) + 1

    # Build a [num_windows, window_size] grid of sample indices, one row per window.
    offsets = torch.arange(0, window_size, device=audio_batch.device)
    starts = torch.arange(0, num_windows * stride, stride, device=audio_batch.device)
    indices = starts.unsqueeze(1) + offsets.unsqueeze(0)

    # Clamp out-of-range indices and remember which positions were actually valid.
    valid_indices = indices < max_audio_length
    indices = torch.minimum(indices, torch.tensor(max_audio_length - 1, device=audio_batch.device))

    # Gather every window of every batch item in a single advanced-indexing call.
    batch_indices = torch.arange(batch_size, device=audio_batch.device)[:, None, None]
    windows = audio_batch[batch_indices, indices]

    # Zero out samples gathered from clamped (out-of-range) positions.
    windows = windows * valid_indices.float()

    return windows
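# Illustrative usage sketch (an added example, not part of the original API): slicing a
# batch of 1 s clips at 16 kHz with the default 160 ms window / 80 ms stride should yield
# (16000 - 2560) // 1280 + 1 = 11 windows of 2560 samples each.
def _example_slice_windows():
    batch = torch.randn(4, 1, 16000)      # [batch_size, 1, max_audio_length]
    windows = slice_windows(batch)         # default sample_rate, window and stride
    assert windows.shape == (4, 11, 2560)
    return windows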
|
|
|
|
|
|
|
|
def large_windows_unfold(audio_batch: torch.Tensor,
                         sample_rate: int = 16000,
                         window_size_ms: int = 3000,
                         stride_ms: int = 250) -> torch.Tensor:
    """
    Alternative implementation using the unfold operation, for potentially better memory
    efficiency on very large audio files.

    Args:
        audio_batch: Input audio of shape [batch_size, 1, max_audio_length]
        sample_rate: Audio sample rate in Hz
        window_size_ms: Window size in milliseconds
        stride_ms: Stride size in milliseconds

    Returns:
        Tensor of shape [batch_size, num_windows, window_size]
    """
    audio_batch = audio_batch.squeeze(1)

    # Convert window/stride from milliseconds to samples.
    window_size = int(window_size_ms * sample_rate / 1000)
    stride = int(stride_ms * sample_rate / 1000)

    # unfold creates a strided view of complete windows only; trailing samples that do not
    # fill a full window are dropped rather than zero-padded.
    windows = audio_batch.unfold(dimension=1, size=window_size, step=stride)

    return windows
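# Illustrative comparison sketch (added example with assumed 10 s clips): unlike
# slice_windows, unfold never pads, so trailing samples that do not fill a complete
# window are dropped. For 10 s at 16 kHz with 3 s windows hopped every 250 ms:
# (160000 - 48000) // 4000 + 1 = 29 windows.
def _example_large_windows_unfold():
    batch = torch.randn(2, 1, 160000)      # 10 s of audio at 16 kHz
    windows = large_windows_unfold(batch)
    assert windows.shape == (2, 29, 48000)
    return windows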
|
|
|
|
|
|
|
|
def large_windows_fold(window_logits):
    """
    UNDER CONSTRUCTION

    Combines predictions from segmented windows produced by the unfold-based implementation.

    Args:
        window_logits: Window predictions of shape [batch_size, num_windows, frames, num_phonemes]

    Returns:
        Tensor of shape [batch_size, total_frames, num_phonemes]
    """
    # The previous body was copied verbatim from large_windows_unfold and referenced names
    # that are not in scope here; until the fold logic is implemented, fail loudly instead
    # of raising a confusing NameError.
    raise NotImplementedError("large_windows_fold is not implemented yet")
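# Hypothetical overlap-add sketch (an assumption about the intended fold behaviour, since
# large_windows_fold is still a stub): sum each window's frames into their absolute frame
# positions and divide by the per-frame overlap count, i.e. a uniform-weight average.
def _example_uniform_fold(window_logits: torch.Tensor, stride_frames: int) -> torch.Tensor:
    batch_size, num_windows, frames_per_window, num_phonemes = window_logits.shape
    total_frames = (num_windows - 1) * stride_frames + frames_per_window
    combined = torch.zeros(batch_size, total_frames, num_phonemes, device=window_logits.device)
    counts = torch.zeros(1, total_frames, 1, device=window_logits.device)
    for i in range(num_windows):
        start = i * stride_frames
        combined[:, start:start + frames_per_window] += window_logits[:, i]
        counts[:, start:start + frames_per_window] += 1
    return combined / counts.clamp_min(1.0)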
|
|
|
|
|
def stich_window_predictions(window_logits: torch.Tensor,
                             original_audio_length: int,
                             cnn_output_size: int,
                             sample_rate: int = 16000,
                             window_size_ms: int = 160,
                             stride_ms: int = 80) -> torch.Tensor:
    """
    Efficiently combines predictions from overlapping windows while maintaining the original
    behavior. Can be used for phoneme logits, embeddings, or CNN output features.

    Args:
        window_logits: Shape [batch_size, num_windows, frames_per_window, output_dim]
        original_audio_length: Original audio length in samples
        cnn_output_size: Number of frames output by CNN for each window
        sample_rate: Audio sample rate (default 16kHz)
        window_size_ms: Window size in milliseconds
        stride_ms: Stride size in milliseconds

    Returns:
        Tensor of shape [batch_size, total_frames, output_dim]
    """
    device = window_logits.device
    batch_size, num_windows, frames_per_window, num_phonemes = window_logits.shape

    # Recompute the windowing geometry in frame units: with 50% window overlap each
    # window contributes stride_frames new frames to the stitched sequence.
    window_size_samples = int(window_size_ms * sample_rate / 1000)
    stride_samples = int(stride_ms * sample_rate / 1000)
    num_windows_total = ((original_audio_length - window_size_samples) // stride_samples) + 1
    total_frames = ((num_windows_total * cnn_output_size) // 2)
    stride_frames = frames_per_window // 2

    # Cosine taper: frames near the window edges get weights near 0, frames near the centre
    # get weights near 1, so overlapping windows cross-fade smoothly.
    window_weights = torch.cos(torch.linspace(-math.pi/2, math.pi/2, frames_per_window, device=device))
    window_weights = window_weights.view(1, frames_per_window, 1)

    combined = torch.zeros(batch_size, total_frames, num_phonemes, device=device)
    weight_sum = torch.zeros(batch_size, total_frames, 1, device=device)

    # Every window except possibly the last fits entirely within total_frames.
    full_windows = num_windows - 1
    if full_windows > 0:
        full_slices = window_logits[:, :full_windows]
        for i in range(full_windows):
            start_frame = i * stride_frames
            end_frame = start_frame + frames_per_window
            combined[:, start_frame:end_frame] += full_slices[:, i] * window_weights
            weight_sum[:, start_frame:end_frame] += window_weights

    # The last window may extend past total_frames; truncate it (and its weights) if needed.
    if num_windows > 0:
        start_frame = (num_windows - 1) * stride_frames
        end_frame = start_frame + frames_per_window

        if end_frame > total_frames:
            frames_to_use = total_frames - start_frame
            window_logits_slice = window_logits[:, -1, :frames_to_use]
            weights = window_weights[:, :frames_to_use]
        else:
            window_logits_slice = window_logits[:, -1]
            weights = window_weights

        combined[:, start_frame:start_frame + window_logits_slice.size(1)] += window_logits_slice * weights
        weight_sum[:, start_frame:start_frame + weights.size(1)] += weights

    # Normalize by the accumulated weights (epsilon guards frames no window covered).
    combined = combined / (weight_sum + 1e-8)
    return combined
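# Illustrative usage sketch (added example, shapes assumed): a 1 s clip sliced with the
# 160 ms / 80 ms defaults gives 11 windows; assuming a CNN that emits 8 frames per window,
# the stitched output has (11 * 8) // 2 = 44 frames.
def _example_stich_window_predictions():
    window_logits = torch.randn(4, 11, 8, 40)   # [batch, windows, frames_per_window, phonemes]
    combined = stich_window_predictions(window_logits,
                                        original_audio_length=16000,
                                        cnn_output_size=8)
    assert combined.shape == (4, 44, 40)
    return combined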
|
|
|
|
|
def stich_window_predictions____non_vectorized(window_logits: torch.Tensor,
                                               original_audio_length: int,
                                               cnn_output_size,
                                               sample_rate: int = 16000,
                                               window_size_ms: int = 160,
                                               stride_ms: int = 80) -> torch.Tensor:
    """Loop-based variant of stich_window_predictions, kept for reference and comparison."""
    device = window_logits.device
    batch_size, num_windows, frames_per_window, num_phonemes = window_logits.shape

    window_size_samples = int(window_size_ms * sample_rate / 1000)
    stride_samples = int(stride_ms * sample_rate / 1000)
    num_windows_total = ((original_audio_length - window_size_samples) // stride_samples) + 1

    frames_per_window_full = cnn_output_size
    total_frames = ((num_windows_total * frames_per_window_full) // 2)

    # Cosine taper, identical to the vectorized version above.
    window_weights = torch.cos(torch.linspace(-math.pi/2, math.pi/2, frames_per_window))
    window_weights = window_weights.to(device).view(1, frames_per_window, 1)

    combined = torch.zeros(batch_size, total_frames, num_phonemes, device=device)
    weight_sum = torch.zeros(batch_size, total_frames, 1, device=device)

    stride_frames = frames_per_window // 2

    # Overlap-add every window at its frame offset, truncating the final window if it
    # runs past the end of the stitched sequence.
    for i in range(num_windows):
        start_frame = i * stride_frames
        end_frame = start_frame + frames_per_window

        if end_frame > total_frames:
            frames_to_use = total_frames - start_frame
            window_logits_slice = window_logits[:, i, :frames_to_use]
            weights = window_weights[:, :frames_to_use]
        else:
            window_logits_slice = window_logits[:, i]
            weights = window_weights

        combined[:, start_frame:end_frame] += window_logits_slice * weights
        weight_sum[:, start_frame:end_frame] += weights

    combined = combined / (weight_sum + 1e-8)
    return combined
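# Sanity-check sketch (added example): on the same input, the loop-based variant above and
# the vectorized stich_window_predictions are expected to agree to numerical precision.
def _example_compare_stitch_variants():
    window_logits = torch.randn(2, 11, 8, 40)
    fast = stich_window_predictions(window_logits, original_audio_length=16000, cnn_output_size=8)
    slow = stich_window_predictions____non_vectorized(window_logits, original_audio_length=16000,
                                                      cnn_output_size=8)
    assert torch.allclose(fast, slow, atol=1e-6)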
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calc_spec_len_ext(wav_lens, window_size_ms, stride_ms, sample_rate, frames_per_window, disable_windowing=False, wav_len_max=1*16000):
    """
    Calculate the total number of frames for the whole audio clip, for each clip in the batch.
    When `disable_windowing=False` there are two levels of windowing: one from the window
    slicing process and the other from the CNN.

    Input:
        wav_lens: tensor of real lengths of the audio clips in samples. Shape: [batch_size]
    Returns:
        spectral_lens: tensor of total number of frames for each audio clip. Shape: [batch_size]
    """
    if not disable_windowing:
        frames_per_window = frames_per_window.to(wav_lens.device)
        window_size_wav = int(window_size_ms * sample_rate / 1000)
        stride_size_wav = int(stride_ms * sample_rate / 1000)
        spectral_lens = []
        for wav_len in wav_lens:
            if wav_len <= window_size_wav:
                # Clip is shorter than a single window: scale the per-window frame count
                # by the fraction of the window that is actually filled.
                num_windows = wav_len.float() / window_size_wav
                total_frames = torch.ceil(frames_per_window * num_windows).long()
            else:
                # Standard case: count the windows, then halve to account for the 50%
                # overlap between consecutive windows (the default stride is half the window).
                num_windows = ((wav_len - window_size_wav) // stride_size_wav) + 1
                total_frames = ((num_windows * frames_per_window) // 2)

            if total_frames < 2:
                raise Exception(f"WARN: spectral_len < 2, wav_len: {wav_len.item()}, "
                                f"output frames: {total_frames.item()}, num_windows: {num_windows.item()}, "
                                f"expected at least {window_size_ms} ms, got {1000 * wav_len.item() / sample_rate} ms")
            spectral_lens.append(total_frames)

        spectral_lens = torch.tensor(spectral_lens, device=wav_lens.device, dtype=torch.long)

    else:
        # No slicing: the CNN sees the padded clip directly, so the frame count is the filled
        # fraction of the padded length mapped onto frames_per_window.
        frames_per_window = frames_per_window.to(wav_lens.device)
        wav_len_per_frame = (wav_len_max / frames_per_window).clone().detach().to(wav_lens.device)

        spectral_lens = torch.tensor([frames_per_window]).repeat(len(wav_lens)).to(wav_lens.device)

        for wi in range(len(wav_lens)):
            spectral_lens[wi] = torch.ceil(wav_lens[wi] / wav_len_per_frame)
            if spectral_lens[wi] > frames_per_window:
                raise Exception(f"WARN: spectral_len > frames_per_window: {spectral_lens[wi]} > {frames_per_window}, wav_len: {wav_lens[wi]}")

    return spectral_lens
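# Illustrative usage sketch (added example, values assumed): frame counts for clips of
# 0.5 s, 1 s and 2 s at 16 kHz with the 160 ms / 80 ms defaults and a CNN that emits
# 8 frames per window: 5, 11 and 24 windows -> 20, 44 and 96 frames respectively.
def _example_calc_spec_len_ext():
    wav_lens = torch.tensor([8000, 16000, 32000])
    frames_per_window = torch.tensor(8)
    lens = calc_spec_len_ext(wav_lens, window_size_ms=160, stride_ms=80,
                             sample_rate=16000, frames_per_window=frames_per_window)
    assert lens.tolist() == [20, 44, 96]
    return lens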
|
|
|
|
|
|
|
|
def calc_spec_len_ext_v1(wav_lens, window_size_ms, stride_ms, sample_rate, frames_per_window, disable_windowing=False, wav_len_max=1*16000):
    """
    Calculate the total number of frames for the whole audio clip, for each clip in the batch.
    Earlier variant of calc_spec_len_ext without the special case for clips shorter than one window.

    Input:
        wav_lens: tensor of real lengths of the audio clips in samples. Shape: [batch_size]
    Returns:
        spectral_lens: tensor of total number of frames for each audio clip. Shape: [batch_size]
    """
    if not disable_windowing:
        window_size_samples = int(window_size_ms * sample_rate / 1000)
        stride_samples = int(stride_ms * sample_rate / 1000)

        frames_per_window = frames_per_window.to(wav_lens.device)

        spectral_lens = []
        for wav_len in wav_lens:
            num_windows = ((wav_len - window_size_samples) // stride_samples) + 1
            # Halve to account for the 50% overlap between consecutive windows.
            total_frames = ((num_windows * frames_per_window) // 2)
            spectral_lens.append(total_frames)

        spectral_lens = torch.tensor(spectral_lens, device=wav_lens.device)

    else:
        # No slicing: map the filled fraction of the padded clip onto frames_per_window.
        frames_per_window = frames_per_window.to(wav_lens.device)
        wav_len_per_frame = (wav_len_max / frames_per_window).clone().detach().to(wav_lens.device)

        spectral_lens = torch.tensor([frames_per_window]).repeat(len(wav_lens)).to(wav_lens.device)

        for wi in range(len(wav_lens)):
            spectral_lens[wi] = torch.ceil(wav_lens[wi] / wav_len_per_frame)
            if spectral_lens[wi] > frames_per_window:
                raise Exception(f"WARN: spectral_len > frames_per_window: {spectral_lens[wi]} > {frames_per_window}, wav_len: {wav_lens[wi]}")

    return spectral_lens
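# Sanity-check sketch (added example): for clips longer than one window, the v1 variant
# and calc_spec_len_ext agree; v1 simply lacks the special case for very short clips.
def _example_compare_spec_len_variants():
    wav_lens = torch.tensor([16000, 32000])
    frames_per_window = torch.tensor(8)
    a = calc_spec_len_ext(wav_lens, 160, 80, 16000, frames_per_window)
    b = calc_spec_len_ext_v1(wav_lens, 160, 80, 16000, frames_per_window)
    assert torch.equal(a, b)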
|
|
|
|
|
|
|
|
|