Skip to content

TacotronSTFT

TacotronSTFT

Bases: Module

Source code in training/preprocess/tacotron_stft.py
class TacotronSTFT(Module):
    r"""STFT front end that produces magnitude and log-mel spectrograms.

    The mel filterbank and the Hann window are registered as buffers, so they
    move together with the module when it is transferred to another device.
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): Length of the filter window (used as ``n_fft``).
            hop_length (int): Number of samples between successive frames.
            win_length (int): Size of the STFT window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            center (bool): Forwarded to ``torch.stft``; whether it pads the
                input signal on both sides.
            mel_fmax (int or None): Maximum frequency for the mel filter bank.
            mel_fmin (float): Minimum frequency for the mel filter bank.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Define the mel filterbank
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )

        mel_basis = torch.from_numpy(mel).float()

        # Define the Hann window
        hann_window = torch.hann_window(win_length)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the STFT of a batch of waves as a real-valued tensor.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: One-sided STFT with the real and imaginary parts
                stacked along a trailing dimension of size 2.

        Raises:
            ValueError: If any sample lies outside [-1, 1] (NaN included).
        """
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently accept unnormalized audio.
        samples = y.detach()
        lowest = torch.min(samples)
        highest = torch.max(samples)
        if not (lowest >= -1 and highest <= 1):
            raise ValueError(
                "Input waveform must lie in [-1, 1], "
                f"got values in [{lowest.item()}, {highest.item()}]",
            )

        # Reflect-pad both ends by half the window/hop difference.
        # NOTE: when ``center`` is True, torch.stft adds its own reflect
        # padding on top of this one.
        pad = int((self.n_fft - self.hop_size) / 2)
        # Reflect padding needs a channel dimension, hence the unsqueeze/squeeze.
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: Linear (magnitude) spectrogram, i.e. the L2 norm of
                the real/imaginary pair of every STFT bin.
        """
        spec = self._spectrogram(y)
        return torch.norm(spec, p=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes mel-spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Magnitude spectrogram of shape
                (B, n_fft // 2 + 1, n_frames)
            torch.FloatTensor: Log-compressed mel-spectrogram of shape
                (B, n_mel_channels, n_frames)
        """
        spec = self._spectrogram(y)

        # Magnitude with a small epsilon under the square root.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Applies dynamic range compression to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Output magnitudes.
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: float = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Applies dynamic range compression to x.

        Computes ``log(clamp(x, min=clip_val) * C)``.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Compression factor.
            clip_val (float): Lower bound applied to ``x`` before the logarithm.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        r"""Computes the mel-spectrogram of a single waveform without gradients.

        Args:
            audio (torch.Tensor): A single waveform; a leading batch dimension
                is added before calling ``forward`` and removed afterwards.

        Returns:
            torch.Tensor: Mel-spectrogram of the waveform, without the batch
                dimension.
        """
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)

__init__(filter_length, hop_length, win_length, n_mel_channels, sampling_rate, center, mel_fmax, mel_fmin=0.0)

TacotronSTFT module that computes mel-spectrograms from a batch of waves.

Parameters:

Name Type Description Default
filter_length int

Length of the filter window.

required
hop_length int

Number of samples between successive frames.

required
win_length int

Size of the STFT window.

required
n_mel_channels int

Number of mel bins.

required
sampling_rate int

Sampling rate of the input waveforms.

required
mel_fmin float

Minimum frequency for the mel filter bank.

0.0
mel_fmax int or None

Maximum frequency for the mel filter bank.

required
center bool

Whether to pad the input signal on both sides.

required
Source code in training/preprocess/tacotron_stft.py
def __init__(
    self,
    filter_length: int,
    hop_length: int,
    win_length: int,
    n_mel_channels: int,
    sampling_rate: int,
    center: bool,
    mel_fmax: Optional[int],
    mel_fmin: float = 0.0,
):
    r"""Store the STFT settings and build the mel filterbank and Hann window.

    Args:
        filter_length (int): Length of the filter window (stored as ``n_fft``).
        hop_length (int): Number of samples between successive frames.
        win_length (int): Size of the STFT window.
        n_mel_channels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input waveforms.
        center (bool): Whether to pad the input signal on both sides.
        mel_fmax (int or None): Maximum frequency for the mel filter bank.
        mel_fmin (float): Minimum frequency for the mel filter bank.
    """
    super().__init__()

    # STFT geometry
    self.n_fft = filter_length
    self.hop_size = hop_length
    self.win_size = win_length
    self.center = center

    # Mel filterbank settings
    self.n_mel_channels = n_mel_channels
    self.sampling_rate = sampling_rate
    self.fmin = mel_fmin
    self.fmax = mel_fmax

    # The filterbank comes back from librosa as a NumPy array; keep it as a
    # float tensor buffer on the module.
    filterbank = librosa.filters.mel(
        sr=sampling_rate,
        n_fft=filter_length,
        n_mels=n_mel_channels,
        fmin=mel_fmin,
        fmax=mel_fmax,
    )
    self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())

    # Hann analysis window, also kept as a buffer
    self.register_buffer("hann_window", torch.hann_window(win_length))

dynamic_range_compression_torch(x, C=1, clip_val=1e-05)

Applies dynamic range compression to x.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
C float

Compression factor.

1
clip_val float

Clipping value.

1e-05

Returns:

Type Description
Tensor

torch.Tensor: Output tensor.

Source code in training/preprocess/tacotron_stft.py
def dynamic_range_compression_torch(
    self,
    x: torch.Tensor,
    C: int = 1,
    clip_val: float = 1e-5,
) -> torch.Tensor:
    r"""Log-compress ``x`` after clamping it from below.

    Args:
        x (torch.Tensor): Input tensor.
        C (float): Compression factor, multiplied in before the logarithm.
        clip_val (float): Smallest value ``x`` may take before the logarithm.

    Returns:
        torch.Tensor: ``log(clamp(x, min=clip_val) * C)``.
    """
    # Clamp first so the logarithm never receives values below clip_val.
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)

forward(y)

Computes mel-spectrograms from a batch of waves.

Parameters:

Name Type Description Default
y FloatTensor

Input waveforms with shape (B, T) in range [-1, 1]

required

Returns:

Type Description
Tensor

torch.FloatTensor: Magnitude spectrogram of shape (B, n_fft // 2 + 1, n_frames)

Tensor

torch.FloatTensor: Log-compressed mel-spectrogram of shape (B, n_mel_channels, n_frames)

Source code in training/preprocess/tacotron_stft.py
def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Turn a batch of waveforms into magnitude and mel spectrograms.

    Args:
        y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

    Returns:
        torch.FloatTensor: Magnitude spectrogram.
        torch.FloatTensor: Mel-spectrogram after dynamic range compression.
    """
    # The last dimension of the STFT output holds its components; collapse
    # it into a magnitude, with a small epsilon under the square root.
    components = self._spectrogram(y)
    power = components.pow(2).sum(-1)
    magnitude = torch.sqrt(power + (1e-9))

    # Project onto the mel filterbank, then compress the dynamic range.
    projected = torch.matmul(self.mel_basis, magnitude)  # type: ignore
    compressed = self.spectral_normalize_torch(projected)

    return magnitude, compressed

linear_spectrogram(y)

Computes the linear spectrogram of a batch of waves.

Parameters:

Name Type Description Default
y Tensor

Input waveforms.

required

Returns:

Type Description
Tensor

torch.Tensor: Linear spectrogram.

Source code in training/preprocess/tacotron_stft.py
def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
    r"""Return the magnitude of the STFT of a batch of waves.

    Args:
        y (torch.Tensor): Input waveforms.

    Returns:
        torch.Tensor: L2 norm taken over the last dimension of the output of
            ``_spectrogram``.
    """
    components = self._spectrogram(y)
    magnitude = torch.norm(components, p=2, dim=-1)
    return magnitude

spectral_normalize_torch(magnitudes)

Applies dynamic range compression to magnitudes.

Parameters:

Name Type Description Default
magnitudes Tensor

Input magnitudes.

required

Returns:

Type Description
Tensor

torch.Tensor: Output magnitudes.

Source code in training/preprocess/tacotron_stft.py
def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
    r"""Compress the dynamic range of ``magnitudes``.

    Delegates to ``dynamic_range_compression_torch`` with its default
    arguments.

    Args:
        magnitudes (torch.Tensor): Input magnitudes.

    Returns:
        torch.Tensor: Output magnitudes.
    """
    compressed = self.dynamic_range_compression_torch(magnitudes)
    return compressed