mirror of
https://github.com/index-tts/index-tts.git
synced 2025-11-28 18:30:25 +08:00
546 lines
16 KiB
Python
546 lines
16 KiB
Python
"""Library implementing convolutional neural networks.

Authors
 * Mirco Ravanelli 2020
 * Jianyuan Zhong 2020
 * Cem Subakan 2021
 * Davide Borra 2021
 * Andreas Nautsch 2022
 * Sarthak Yadav 2022
"""
|
|
|
|
import logging
|
|
import math
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchaudio
|
|
|
|
|
|
class SincConv(nn.Module):
    """This function implements SincConv (SincNet).

    M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
    SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)

    Arguments
    ---------
    out_channels : int
        It is the number of output channels. Must be divisible by the
        number of input channels.
    kernel_size: int
        Kernel size of the convolutional filters. Must be odd: the filter is
        built as (left half, centre tap, mirrored right half).
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    padding_mode : str
        This flag specifies the type of padding used for "same" padding.
        See torch.nn documentation for more information.
    sample_rate : int
        Sampling rate of the input signals. It is only used for sinc_conv.
    min_low_hz : float
        Lowest possible frequency (in Hz) for a filter. It is only used for
        sinc_conv.
    min_band_hz : float
        Lowest possible value (in Hz) for a filter bandwidth.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 16000])
    >>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
    >>> out_tensor = conv(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 16000, 25])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        padding_mode="reflect",
        sample_rate=16000,
        min_low_hz=50,
        min_band_hz=50,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # input shape inference
        if input_shape is None and self.in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if self.in_channels is None:
            # Also validates that the kernel size is odd.
            self.in_channels = self._check_input_shape(input_shape)
        else:
            # The explicit-in_channels path must enforce the same constraint:
            # an even kernel would build filters with kernel_size + 1 taps and
            # only fail later, inside forward(), with an opaque view() error.
            self._check_kernel_size()

        if self.out_channels % self.in_channels != 0:
            raise ValueError(
                "Number of output channels must be divisible by in_channels"
            )

        # Initialize Sinc filters
        self._init_sinc_conv()

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected.

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs.
        """
        # Move time to the last axis (no-op for a 2d (batch, time) input).
        x = x.transpose(1, -1)

        # Not read internally any more; still set for code that inspects it.
        self.device = x.device

        unsqueeze = x.ndim == 2
        if unsqueeze:
            # Add the singleton channel axis expected by conv1d.
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-only zero padding so no output depends on future samples.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s."
                % (self.padding,)
            )

        sinc_filters = self._get_sinc_filters()

        wx = F.conv1d(
            x,
            sinc_filters,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=self.in_channels,
        )

        if unsqueeze:
            # Only removes the channel axis when out_channels == 1.
            wx = wx.squeeze(1)

        # Back to the (batch, time, channel) convention.
        wx = wx.transpose(1, -1)

        return wx

    def _check_kernel_size(self):
        """Raises a ValueError if the kernel size is not an odd number."""
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        Raises a ValueError for shapes that are not 2d or 3d, or when the
        kernel size is even.
        """
        if len(shape) == 2:
            in_channels = 1
        elif len(shape) == 3:
            in_channels = shape[-1]
        else:
            raise ValueError(
                "sincconv expects 2d or 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd
        self._check_kernel_size()
        return in_channels

    def _get_sinc_filters(self):
        """This function creates the sinc-filters to be used for sinc-conv.

        Returns
        -------
        filters : torch.Tensor
            Tensor of shape (out_channels, 1, kernel_size).
        """
        # Computing the low frequencies of the filters
        low = self.min_low_hz + torch.abs(self.low_hz_)

        # Setting minimum band and minimum freq
        high = torch.clamp(
            low + self.min_band_hz + torch.abs(self.band_hz_),
            self.min_low_hz,
            self.sample_rate / 2,
        )
        band = (high - low)[:, 0]

        # The time axis and window are buffers, so they normally already live
        # with the parameters; align them explicitly to stay robust.
        n_ = self.n_.to(device=low.device, dtype=low.dtype)
        window_ = self.window_.to(device=low.device, dtype=low.dtype)

        # Passing from n_ to the corresponding f_times_t domain
        f_times_t_low = torch.matmul(low, n_)
        f_times_t_high = torch.matmul(high, n_)

        # Left part of the filters.
        band_pass_left = (
            (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (n_ / 2)
        ) * window_

        # Central element of the filter
        band_pass_center = 2 * band.view(-1, 1)

        # Right part of the filter (sinc filters are symmetric)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        # Combining left, central, and right part of the filter
        band_pass = torch.cat(
            [band_pass_left, band_pass_center, band_pass_right], dim=1
        )

        # Amplitude normalization
        band_pass = band_pass / (2 * band[:, None])

        # Setting up the filter coefficients
        filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        return filters

    def _init_sinc_conv(self):
        """Initializes the parameters of the sinc_conv layer."""
        # Initialize filterbanks such that they are equally spaced in Mel scale
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)

        mel = torch.linspace(
            self._to_mel(self.min_low_hz),
            self._to_mel(high_hz),
            self.out_channels + 1,
        )

        hz = self._to_hz(mel)

        # Making the lower cut-off frequencies and the bandwidths learnable
        self.low_hz_ = nn.Parameter(hz[:-1].unsqueeze(1))
        self.band_hz_ = nn.Parameter((hz[1:] - hz[:-1]).unsqueeze(1))

        # Hamming window (left half only)
        n_lin = torch.linspace(
            0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
        )
        window = 0.54 - 0.46 * torch.cos(
            2 * math.pi * n_lin / self.kernel_size
        )

        # Time axis (only half is needed due to symmetry)
        n = (self.kernel_size - 1) / 2.0
        time_axis = (
            2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
        )

        # Non-persistent buffers follow module.to()/.double()/.half() but are
        # left out of the state_dict, so existing checkpoints still load.
        self.register_buffer("window_", window, persistent=False)
        self.register_buffer("n_", time_axis, persistent=False)

    def _to_mel(self, hz):
        """Converts frequency in Hz to the mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    def _to_hz(self, mel):
        """Converts frequency in the mel scale to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """This function pads the time axis (using ``self.padding_mode``)
        such that the length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded tensor.
        """
        # The number of channels is passed as the length argument; for
        # stride == 1 the resulting padding does not depend on this value.
        L_in = self.in_channels

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x
|
|
|
|
|
|
class Conv1d(nn.Module):
    """This function implements 1d convolution.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    groups : int
        Number of blocked connections from input channels to output channels.
    bias : bool
        Whether to add a bias term to convolution operation.
    padding_mode : str
        This flag specifies the type of padding used for "same" padding.
        See torch.nn documentation for more information.
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.
    weight_norm : bool
        If True, use weight normalization,
        to be removed with self.remove_weight_norm() at inference
    conv_init : str
        Weight initialization for the convolution network. One of
        "kaiming", "zero" or "normal"; any other value keeps the default
        initialization of ``torch.nn.Conv1d``.
    default_padding: str or int
        This sets the default padding mode that will be used by the pytorch Conv1d backend.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=False,
        weight_norm=False,
        conv_init=None,
        default_padding=0,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        # Set to True by _check_input_shape when a 2d input shape is given.
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        self.in_channels = in_channels

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=default_padding,
            groups=groups,
            bias=bias,
        )

        if conv_init == "kaiming":
            nn.init.kaiming_normal_(self.conv.weight)
        elif conv_init == "zero":
            nn.init.zeros_(self.conv.weight)
        elif conv_init == "normal":
            nn.init.normal_(self.conv.weight, std=1e-6)

        if weight_norm:
            self.conv = nn.utils.weight_norm(self.conv)

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected. With
            ``skip_transpose=True`` the layout is (batch, channel, time).

        Returns
        -------
        wx : torch.Tensor
            The convolved outputs.
        """
        if not self.skip_transpose:
            # Move time to the last axis, as expected by torch.nn.Conv1d.
            x = x.transpose(1, -1)

        if self.unsqueeze:
            # Add the singleton channel axis for 2d (batch, time) inputs.
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )
        elif self.padding == "causal":
            # Left-only zero padding so no output depends on future samples.
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            # %-formatting (not str concatenation) so a non-string padding
            # value still produces this ValueError rather than a TypeError.
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got %s"
                % (self.padding,)
            )

        wx = self.conv(x)

        if self.unsqueeze:
            # Only removes the channel axis when out_channels == 1.
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        """This function pads the time axis (using ``self.padding_mode``)
        such that the length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.

        Returns
        -------
        x : torch.Tensor
            The padded outputs.
        """
        # The number of channels is passed as the length argument; for
        # stride == 1 the resulting padding does not depend on this value.
        L_in = self.in_channels

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels.

        Raises a ValueError for shapes that are not 2d or 3d, or when padding
        is requested together with an even kernel size.
        """
        # Validate the rank first so that skip_transpose cannot let
        # unsupported shapes through.
        if len(shape) not in (2, 3):
            raise ValueError(
                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
            )

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            # (batch, channel, time)
            in_channels = shape[1]
        else:
            # (batch, time, channel)
            in_channels = shape[2]

        # Kernel size must be odd
        if not self.padding == "valid" and self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )

        return in_channels

    def remove_weight_norm(self):
        """Removes weight normalization at inference if used during training."""
        self.conv = nn.utils.remove_weight_norm(self.conv)
|
|
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add on each side
    of the time axis when padding before a convolution.

    Arguments
    ---------
    L_in : int
        Length used to derive the output length in the ``stride <= 1`` case.
    stride: int
        Stride of the convolution.
    kernel_size : int
        Size of the kernel.
    dilation : int
        Dilation of the convolution (only used when ``stride <= 1``).

    Returns
    -------
    padding : list of int
        Two-element list ``[left, right]`` with the size of the padding
        to be added on each side.
    """
    # Strided convolutions: pad half a kernel on both sides.
    if stride > 1:
        half_kernel = math.floor(kernel_size / 2)
        return [half_kernel, half_kernel]

    # Otherwise pad by half of what the un-padded convolution would trim.
    receptive_span = dilation * (kernel_size - 1)
    L_out = math.floor((L_in - receptive_span - 1) / stride) + 1
    per_side = math.floor((L_in - L_out) / 2)
    return [per_side, per_side]
|
|
|