Multi Resolution STFT Loss

MultiResolutionSTFTLoss

Bases: Module

Multi resolution STFT loss module.

MultiResolutionSTFTLoss is a PyTorch module that computes the spectral convergence loss and the log STFT magnitude loss between a predicted signal and a ground-truth signal at several STFT resolutions. It is designed for speech and audio signal processing tasks such as speech enhancement and source separation.

The constructor takes a list of tuples, each giving the FFT size, hop size (called "shift size" in the argument description), and window length of one resolution. For each resolution it creates an STFTLoss module, a PyTorch module that computes the STFT of the signals and the corresponding magnitude spectrograms, and returns the spectral convergence and log STFT magnitude losses for that resolution.
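To make the three numbers in each tuple concrete, the sketch below works out the time/frequency grid that a given resolution produces. The sample rate and the example resolutions are assumptions chosen for illustration, not values taken from this page, and `describe_resolution` is a hypothetical helper. The bin count assumes a one-sided STFT of a real signal.

```python
# Hypothetical settings for illustration; neither the sample rate nor these
# resolutions come from this page.
SAMPLE_RATE = 22050
RESOLUTIONS = [(512, 50, 240), (1024, 120, 600), (2048, 240, 1200)]


def describe_resolution(fft_size, hop_size, win_length, sample_rate=SAMPLE_RATE):
    """Summarise the time/frequency grid of one (FFT, hop, window) tuple."""
    return {
        "freq_bins": fft_size // 2 + 1,                # one-sided spectrum
        "bin_spacing_hz": sample_rate / fft_size,      # distance between FFT bins
        "hop_ms": 1000 * hop_size / sample_rate,       # time step between frames
        "window_ms": 1000 * win_length / sample_rate,  # audio analysed per frame
    }


for resolution in RESOLUTIONS:
    print(resolution, describe_resolution(*resolution))
```

Larger FFT sizes give more closely spaced frequency bins, while smaller hop sizes and windows give finer steps in time.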

Both losses measure how far the predicted magnitude spectrogram is from the ground-truth one, so lower values mean a closer match. The spectral convergence loss compares the two magnitude spectrograms directly, while the log STFT magnitude loss compares their logarithms. Taking the logarithm compresses the dynamic range, much as a decibel scale does, which is more perceptually meaningful than the linear scale.
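As a worked example, the sketch below evaluates both losses on a tiny pair of magnitude spectrograms. It assumes the commonly used definitions: spectral convergence as the Frobenius norm of the magnitude error divided by the Frobenius norm of the ground-truth magnitude, and the log STFT magnitude loss as the mean absolute difference of the natural-log magnitudes. The STFTLoss source is not shown on this page, so its exact formulas may differ, and the function names here are hypothetical.

```python
import math


def _flatten(spec):
    """Turn a nested list (frequency x time) into a flat list of values."""
    return [value for row in spec for value in row]


def spectral_convergence(pred_mag, true_mag):
    """Norm of the magnitude error, relative to the norm of the target."""
    pred, true = _flatten(pred_mag), _flatten(true_mag)
    error_norm = math.sqrt(sum((t - p) ** 2 for p, t in zip(pred, true)))
    true_norm = math.sqrt(sum(t ** 2 for t in true))
    return error_norm / true_norm


def log_magnitude_loss(pred_mag, true_mag):
    """Mean absolute difference between the log magnitudes."""
    pred, true = _flatten(pred_mag), _flatten(true_mag)
    total = sum(abs(math.log(t) - math.log(p)) for p, t in zip(pred, true))
    return total / len(true)


true_mag = [[1.0, 2.0], [2.0, 4.0]]
pred_mag = [[1.0, 2.0], [2.0, 1.0]]  # only the last bin is wrong

sc = spectral_convergence(pred_mag, true_mag)  # 3 / 5 = 0.6
mag = log_magnitude_loss(pred_mag, true_mag)   # ln(4) / 4, about 0.347
print(sc, mag)
```

Both functions return zero when the two spectrograms are identical.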

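The forward pass returns two scalar tensors: the spectral convergence loss and the log STFT magnitude loss, each averaged over all resolutions. Because short windows resolve fine detail in time and long windows resolve fine detail in frequency, averaging over several resolutions lets the loss capture both fine-grained and coarse-grained spectral information in the predicted and ground-truth signals.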
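The following sketch shows how the averaged losses might be computed and used in a training step. Because the STFTLoss source is not shown on this page, `SketchSTFTLoss` is a hypothetical stand-in built on `torch.stft` with a Hann window; the resolutions, the clamp value, and the equal weighting of the two losses are also assumptions.

```python
import torch
from torch import nn
from torch.nn import functional as F


class SketchSTFTLoss(nn.Module):
    """Hypothetical stand-in for STFTLoss, whose source is not shown here."""

    def __init__(self, fft_size: int, hop_size: int, win_length: int):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_length = win_length
        # A buffer moves with the module when .to(device) is called.
        self.register_buffer("window", torch.hann_window(win_length))

    def magnitude(self, signal: torch.Tensor) -> torch.Tensor:
        spec = torch.stft(
            signal,
            n_fft=self.fft_size,
            hop_length=self.hop_size,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        # Clamp so the logarithm stays finite for silent bins.
        return spec.abs().clamp(min=1e-7)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        x_mag, y_mag = self.magnitude(x), self.magnitude(y)
        sc_loss = torch.linalg.norm(y_mag - x_mag) / torch.linalg.norm(y_mag)
        mag_loss = F.l1_loss(torch.log(x_mag), torch.log(y_mag))
        return sc_loss, mag_loss


# Assumed (FFT size, hop size, window length) tuples; not taken from this page.
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
stft_losses = nn.ModuleList([SketchSTFTLoss(*r) for r in resolutions])


def multi_resolution_loss(x: torch.Tensor, y: torch.Tensor):
    """Average both losses over all resolutions, as described above."""
    per_resolution = [loss_fn(x, y) for loss_fn in stft_losses]
    sc_loss = torch.stack([sc for sc, _ in per_resolution]).mean()
    mag_loss = torch.stack([mag for _, mag in per_resolution]).mean()
    return sc_loss, mag_loss


torch.manual_seed(0)
target = torch.randn(2, 8000)  # (B, T) ground-truth batch
predicted = (target + 0.1 * torch.randn(2, 8000)).requires_grad_()

sc_loss, mag_loss = multi_resolution_loss(predicted, target)
(sc_loss + mag_loss).backward()  # equal weighting is an assumption
print(sc_loss.item(), mag_loss.item(), predicted.grad.shape)
```

The signal length must exceed half of the largest FFT size, since `torch.stft` pads each signal by reflection when `center=True` (its default).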

Source code in training/loss/multi_resolution_stft_loss.py
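import torch
from torch.nn import Module

# STFTLoss is defined elsewhere in this project; its import is not shown on this page.


class MultiResolutionSTFTLoss(Module):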
    r"""Multi resolution STFT loss module.

    MultiResolutionSTFTLoss is a PyTorch module that computes the spectral convergence loss and the log STFT magnitude loss between a predicted signal and a ground-truth signal at several STFT resolutions. It is designed for speech and audio signal processing tasks such as speech enhancement and source separation.

    The constructor takes a list of tuples, each giving the FFT size, hop (shift) size, and window length of one resolution. For each resolution it creates an STFTLoss module, a PyTorch module that computes the STFT of the signals and the corresponding magnitude spectrograms, and returns the spectral convergence and log STFT magnitude losses for that resolution.

    Both losses measure how far the predicted magnitude spectrogram is from the ground-truth one, so lower values mean a closer match. The spectral convergence loss compares the two magnitude spectrograms directly, while the log STFT magnitude loss compares their logarithms. Taking the logarithm compresses the dynamic range, much as a decibel scale does, which is more perceptually meaningful than the linear scale.

    The forward pass returns two scalar tensors: the spectral convergence loss and the log STFT magnitude loss, each averaged over all resolutions. Averaging over several resolutions lets the loss capture both fine-grained and coarse-grained spectral information in the predicted and ground-truth signals.
    """

    def __init__(
        self,
        resolutions: list[tuple[int, int, int]],
    ):
        r"""Initialize Multi resolution STFT loss module.

        Args:
            resolutions (list): List of (FFT size, shift size, window length).
        """
        super().__init__()

        self.stft_losses = torch.nn.ModuleList(
            [STFTLoss(fs, ss, wl) for fs, ss, wl in resolutions],
        )

    def forward(
        self, x: torch.Tensor, y: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = torch.tensor(0.0, device=x.device)
        mag_loss = torch.tensor(0.0, device=x.device)

        # Compute the spectral convergence and log STFT magnitude losses for each resolution
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l

        # Average the losses across all resolutions
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss

__init__(resolutions)

Initialize Multi resolution STFT loss module.

Parameters:

Name         Type  Description                                     Default
resolutions  list  List of (FFT size, shift size, window length).  required

Source code in training/loss/multi_resolution_stft_loss.py
def __init__(
    self,
    resolutions: list[tuple[int, int, int]],
):
    r"""Initialize Multi resolution STFT loss module.

    Args:
        resolutions (list): List of (FFT size, shift size, window length).
    """
    super().__init__()

    self.stft_losses = torch.nn.ModuleList(
        [STFTLoss(fs, ss, wl) for fs, ss, wl in resolutions],
    )

forward(x, y)

Calculate forward propagation.

Parameters:

Name  Type    Description                 Default
x     Tensor  Predicted signal (B, T).    required
y     Tensor  Groundtruth signal (B, T).  required

Returns:

Type    Description
Tensor  Multi resolution spectral convergence loss value.
Tensor  Multi resolution log STFT magnitude loss value.

Source code in training/loss/multi_resolution_stft_loss.py
def forward(
    self, x: torch.Tensor, y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Calculate forward propagation.

    Args:
        x (Tensor): Predicted signal (B, T).
        y (Tensor): Groundtruth signal (B, T).

    Returns:
        Tensor: Multi resolution spectral convergence loss value.
        Tensor: Multi resolution log STFT magnitude loss value.
    """
    sc_loss = torch.tensor(0.0, device=x.device)
    mag_loss = torch.tensor(0.0, device=x.device)

    # Compute the spectral convergence and log STFT magnitude losses for each resolution
    for f in self.stft_losses:
        sc_l, mag_l = f(x, y)
        sc_loss += sc_l
        mag_loss += mag_l

    # Average the losses across all resolutions
    sc_loss /= len(self.stft_losses)
    mag_loss /= len(self.stft_losses)

    return sc_loss, mag_loss