Source code for matchbox.MaskedConv1d

# Modified code: the original was taken from the NVIDIA NeMo library: https://github.com/NVIDIA/NeMo/blob/41fcf4daccee9aa5082431f4ae89c76fc685eac9/nemo/collections/asr/parts/submodules/jasper.py#L90

import torch
import torch.nn as nn

def _masked_conv_init_lens(lens: torch.Tensor, current_maxlen: int, original_maxlen: torch.Tensor):
    # Grow the cached index tensor only when the current batch is longer than
    # any batch seen so far; otherwise keep the existing cache.
    if current_maxlen > original_maxlen:
        new_lens = torch.arange(current_maxlen)
        new_max_lens = torch.tensor(current_maxlen)
    else:
        new_lens = lens
        new_max_lens = original_maxlen
    return new_lens, new_max_lens
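
# Illustration (added here, not part of the upstream NeMo source): the cache
# only ever grows, e.g.
#   _masked_conv_init_lens(torch.tensor(0), 8, torch.tensor(0))
#       -> (torch.arange(8), torch.tensor(8))
#   _masked_conv_init_lens(torch.arange(8), 5, torch.tensor(8))
#       -> (torch.arange(8), torch.tensor(8))   # unchanged, since 5 <= 8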

class MaskedConv1d(nn.Module):
    __constants__ = ["use_conv_mask", "real_out_channels", "heads"]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        heads=-1,
        bias=False,
        use_mask=True,
    ):
        super(MaskedConv1d, self).__init__()

        if not (heads == -1 or groups == in_channels):
            raise ValueError("Only use heads for depthwise convolutions")

        self.real_out_channels = out_channels
        if heads != -1:
            in_channels = heads
            out_channels = heads
            groups = heads

        # preserve original padding
        self._padding = padding

        # if padding is a tuple/list, it is considered as asymmetric padding
        if type(padding) in (tuple, list):
            self.pad_layer = nn.ConstantPad1d(padding, value=0.0)
            # reset padding for conv since pad_layer will handle this
            padding = 0
        else:
            self.pad_layer = None

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.use_mask = use_mask
        self.heads = heads

        # Calculations for "same" padding cache
        self.same_padding = (self.conv.stride[0] == 1) and (
            2 * self.conv.padding[0] == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
        )
        if self.pad_layer is None:
            self.same_padding_asymmetric = False
        else:
            self.same_padding_asymmetric = (self.conv.stride[0] == 1) and (
                sum(self._padding) == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
            )

        # `self.lens` caches consecutive integers from 0 to `self.max_len` that are used to
        # compute the mask for a batch. Recomputed to a bigger size as needed. Stored on the
        # device of the latest batch of lens.
        if self.use_mask:
            self.max_len = torch.tensor(0)
            self.lens = torch.tensor(0)
    def get_seq_len(self, lens):
        if self.same_padding or self.same_padding_asymmetric:
            # "same" padding leaves the sequence length unchanged
            return lens

        # standard Conv1d output-length formula
        if self.pad_layer is None:
            return (
                lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
            ) // self.conv.stride[0] + 1
        else:
            return (
                lens + sum(self._padding) - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
            ) // self.conv.stride[0] + 1
    def forward(self, x, lens):
        if self.use_mask:
            x, lens = self.update_masked_length(x, lens)

        # asymmetric pad if necessary
        if self.pad_layer is not None:
            x = self.pad_layer(x)

        sh = x.shape
        if self.heads != -1:
            x = x.view(-1, self.heads, sh[-1])

        out = self.conv(x)

        if self.heads != -1:
            out = out.view(sh[0], self.real_out_channels, -1)

        return out, lens
    def update_masked_length(self, x, lens):
        max_len = x.size(2)
        self.lens, self.max_len = _masked_conv_init_lens(self.lens, max_len, self.max_len)
        self.lens = self.lens.to(lens.device)
        mask = self.lens[:max_len].unsqueeze(0) < lens.unsqueeze(1)
        x = x * mask.unsqueeze(1).to(device=x.device)
        lens = self.get_seq_len(lens)
        return x, lens
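
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module).
# A minimal example, assuming a batch of four single-channel sequences of
# length 16 with per-example valid lengths `lens`: padded time steps are
# zeroed before the convolution, and the returned lengths follow the conv's
# stride/padding arithmetic.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    conv = MaskedConv1d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1)
    x = torch.randn(4, 1, 16)
    lens = torch.tensor([16, 12, 9, 5])
    out, out_lens = conv(x, lens)
    print(out.shape)   # torch.Size([4, 8, 8])
    print(out_lens)    # tensor([8, 6, 5, 3])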