Skip to content

Univnet loss

UnivnetLoss

Bases: Module

UnivnetLoss is a PyTorch Module that calculates the generator and discriminator losses for Univnet.

Source code in training/loss/univnet_loss.py
class UnivnetLoss(Module):
    r"""UnivnetLoss is a PyTorch Module that calculates the generator and discriminator losses for Univnet.

    The generator objective is the sum of a least-squares adversarial ("score")
    term, a weighted multi-resolution STFT term, an ESR term and an SNR term.
    The discriminator objective is the least-squares GAN loss averaged over all
    discriminator outputs.
    """

    def __init__(self):
        r"""Initializes the UnivnetLoss module.

        Reads the STFT loss weight from ``VocoderBasicConfig`` and the STFT
        resolutions (``mrd.resolutions``) from ``VocoderModelConfig``.
        """
        super().__init__()

        train_config = VocoderBasicConfig()

        # Weight applied to the summed STFT criterion terms in forward().
        self.stft_lamb = train_config.stft_lamb
        self.model_config = VocoderModelConfig()

        self.stft_criterion = MultiResolutionSTFTLoss(self.model_config.mrd.resolutions)
        self.esr_loss = ESRLoss()
        # NOTE(review): sisdr_loss and sdsdr_loss are constructed but not used by
        # forward(); they are kept so existing references to these attributes work.
        self.sisdr_loss = SISDRLoss()
        self.snr_loss = SNRLoss()
        self.sdsdr_loss = SDSDRLoss()

    def forward(
        self,
        audio: Tensor,
        fake_audio: Tensor,
        res_fake: List[Tuple[Tensor, Tensor]],
        period_fake: List[Tuple[Tensor, Tensor]],
        res_real: List[Tuple[Tensor, Tensor]],
        period_real: List[Tuple[Tensor, Tensor]],
    ) -> Tuple[
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
        Tensor,
    ]:
        r"""Calculate the losses for the generator and discriminator.

        Args:
            audio (torch.Tensor): The real audio samples. The code squeezes dim 1
                and reads ``shape[2]`` as the length, so presumably
                ``(batch, 1, time)`` — TODO confirm.
            fake_audio (torch.Tensor): The generated audio samples, in the same
                layout as ``audio``; its last dim may differ in length.
            res_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio.
            period_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio in the period.
            res_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio.
            period_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio in the period.

            Only the second element (the score) of each pair is used.

        Returns:
            tuple: Six tensors, in order: total generator (Univnet) loss,
            discriminator loss, weighted STFT loss, score loss, ESR loss and
            SNR loss.

        Raises:
            ValueError: If there are no discriminator outputs at all, or if the
                number of fake outputs differs from the number of real outputs.
        """
        fake_outputs = res_fake + period_fake
        real_outputs = res_real + period_real
        num_outputs = len(fake_outputs)

        # Fail loudly instead of dividing by zero below, or letting zip()
        # silently drop unmatched discriminator outputs.
        if num_outputs == 0:
            raise ValueError("UnivnetLoss requires at least one discriminator output.")
        if len(real_outputs) != num_outputs:
            raise ValueError(
                "UnivnetLoss expects the same number of fake and real discriminator "
                f"outputs, got {num_outputs} fake and {len(real_outputs)} real."
            )

        # STFT loss: sum of the two terms returned by the criterion, weighted by stft_lamb.
        sc_loss, mag_loss = self.stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
        stft_loss = (sc_loss + mag_loss) * self.stft_lamb

        # Right-pad the fake audio to the real audio's length. If fake_audio is
        # longer, the padding is negative and F.pad trims it instead.
        padding = audio.shape[2] - fake_audio.shape[2]
        fake_audio_padded = torch.nn.functional.pad(fake_audio, (0, padding))

        esr_loss = self.esr_loss.forward(fake_audio_padded, audio)
        snr_loss = self.snr_loss.forward(fake_audio_padded, audio)

        # Least-squares adversarial loss for the generator: push fake scores towards 1.
        score_loss = torch.tensor(0.0, device=audio.device)
        for _, score_fake in fake_outputs:
            score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
        score_loss = score_loss / num_outputs

        total_loss_gen = score_loss + stft_loss + esr_loss + snr_loss

        # Least-squares discriminator loss: real scores towards 1, fake scores towards 0.
        # NOTE(review): fake scores are not detached here; presumably the caller
        # passes detached discriminator outputs for this term — confirm.
        total_loss_disc = torch.tensor(0.0, device=audio.device)
        for (_, score_fake), (_, score_real) in zip(fake_outputs, real_outputs):
            total_loss_disc += torch.mean(torch.pow(score_real - 1.0, 2)) + torch.mean(
                torch.pow(score_fake, 2)
            )
        total_loss_disc = total_loss_disc / num_outputs

        return (
            total_loss_gen,
            total_loss_disc,
            stft_loss,
            score_loss,
            esr_loss,
            snr_loss,
        )

__init__()

Initializes the UnivnetLoss module.

Source code in training/loss/univnet_loss.py
def __init__(self):
    r"""Set up the sub-losses that make up the Univnet objective.

    The STFT weight is taken from ``VocoderBasicConfig`` and the STFT
    resolutions from ``VocoderModelConfig().mrd.resolutions``.
    """
    super().__init__()

    basic_config = VocoderBasicConfig()
    self.stft_lamb = basic_config.stft_lamb

    model_config = VocoderModelConfig()
    self.model_config = model_config
    stft_resolutions = model_config.mrd.resolutions

    # Spectral criterion followed by the time-domain loss terms.
    self.stft_criterion = MultiResolutionSTFTLoss(stft_resolutions)
    self.esr_loss = ESRLoss()
    self.sisdr_loss = SISDRLoss()
    self.snr_loss = SNRLoss()
    self.sdsdr_loss = SDSDRLoss()

forward(audio, fake_audio, res_fake, period_fake, res_real, period_real)

Calculate the losses for the generator and discriminator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `audio` | `Tensor` | The real audio samples. | *required* |
| `fake_audio` | `Tensor` | The generated audio samples. | *required* |
| `res_fake` | `List[Tuple[Tensor, Tensor]]` | The discriminator's output for the fake audio. | *required* |
| `period_fake` | `List[Tuple[Tensor, Tensor]]` | The discriminator's output for the fake audio in the period. | *required* |
| `res_real` | `List[Tuple[Tensor, Tensor]]` | The discriminator's output for the real audio. | *required* |
| `period_real` | `List[Tuple[Tensor, Tensor]]` | The discriminator's output for the real audio in the period. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `tuple` | `Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]` | A tuple containing the total generator (Univnet) loss, discriminator loss, STFT loss, score loss, ESR loss and SNR loss. |

Source code in training/loss/univnet_loss.py
def forward(
    self,
    audio: Tensor,
    fake_audio: Tensor,
    res_fake: List[Tuple[Tensor, Tensor]],
    period_fake: List[Tuple[Tensor, Tensor]],
    res_real: List[Tuple[Tensor, Tensor]],
    period_real: List[Tuple[Tensor, Tensor]],
) -> Tuple[
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Tensor,
]:
    r"""Calculate the losses for the generator and discriminator.

    Args:
        audio (torch.Tensor): The real audio samples. The code squeezes dim 1
            and reads ``shape[2]`` as the length, so presumably
            ``(batch, 1, time)`` — TODO confirm.
        fake_audio (torch.Tensor): The generated audio samples, in the same
            layout as ``audio``; its last dim may differ in length.
        res_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio.
        period_fake (List[Tuple[Tensor, Tensor]]): The discriminator's output for the fake audio in the period.
        res_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio.
        period_real (List[Tuple[Tensor, Tensor]]): The discriminator's output for the real audio in the period.

        Only the second element (the score) of each pair is used.

    Returns:
        tuple: Six tensors, in order: total generator (Univnet) loss,
        discriminator loss, weighted STFT loss, score loss, ESR loss and
        SNR loss.

    Raises:
        ValueError: If there are no discriminator outputs at all, or if the
            number of fake outputs differs from the number of real outputs.
    """
    fake_outputs = res_fake + period_fake
    real_outputs = res_real + period_real
    num_outputs = len(fake_outputs)

    # Fail loudly instead of dividing by zero below, or letting zip()
    # silently drop unmatched discriminator outputs.
    if num_outputs == 0:
        raise ValueError("UnivnetLoss requires at least one discriminator output.")
    if len(real_outputs) != num_outputs:
        raise ValueError(
            "UnivnetLoss expects the same number of fake and real discriminator "
            f"outputs, got {num_outputs} fake and {len(real_outputs)} real."
        )

    # STFT loss: sum of the two terms returned by the criterion, weighted by stft_lamb.
    sc_loss, mag_loss = self.stft_criterion(fake_audio.squeeze(1), audio.squeeze(1))
    stft_loss = (sc_loss + mag_loss) * self.stft_lamb

    # Right-pad the fake audio to the real audio's length. If fake_audio is
    # longer, the padding is negative and F.pad trims it instead.
    padding = audio.shape[2] - fake_audio.shape[2]
    fake_audio_padded = torch.nn.functional.pad(fake_audio, (0, padding))

    esr_loss = self.esr_loss.forward(fake_audio_padded, audio)
    snr_loss = self.snr_loss.forward(fake_audio_padded, audio)

    # Least-squares adversarial loss for the generator: push fake scores towards 1.
    score_loss = torch.tensor(0.0, device=audio.device)
    for _, score_fake in fake_outputs:
        score_loss += torch.mean(torch.pow(score_fake - 1.0, 2))
    score_loss = score_loss / num_outputs

    total_loss_gen = score_loss + stft_loss + esr_loss + snr_loss

    # Least-squares discriminator loss: real scores towards 1, fake scores towards 0.
    # NOTE(review): fake scores are not detached here; presumably the caller
    # passes detached discriminator outputs for this term — confirm.
    total_loss_disc = torch.tensor(0.0, device=audio.device)
    for (_, score_fake), (_, score_real) in zip(fake_outputs, real_outputs):
        total_loss_disc += torch.mean(torch.pow(score_real - 1.0, 2)) + torch.mean(
            torch.pow(score_fake, 2)
        )
    total_loss_disc = total_loss_disc / num_outputs

    return (
        total_loss_gen,
        total_loss_disc,
        stft_loss,
        score_loss,
        esr_loss,
        snr_loss,
    )