
Multi Period Discriminator

MultiPeriodDiscriminator

Bases: Module

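MultiPeriodDiscriminator implements the multi-period discriminator network for the UnivNet vocoder. It holds one DiscriminatorP per period listed in model_config.mpd.periods and applies each of them to the same input.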

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_config | VocoderModelConfig | The configuration object for the UnivNet vocoder model. | required |

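
As a rough sketch of what the constructor consumes: the class itself reads only `model_config.mpd.periods`, then passes the whole config on to each DiscriminatorP, which may read further fields. The dataclasses below (`MPDConfigSketch`, `VocoderModelConfigSketch`) are hypothetical stand-ins for the real VocoderModelConfig, and the period values are assumed for illustration only.

```python
from dataclasses import dataclass, field


@dataclass
class MPDConfigSketch:
    # Hypothetical stand-in; the period values are assumed for illustration.
    periods: list[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])


@dataclass
class VocoderModelConfigSketch:
    # Hypothetical stand-in exposing only the field the constructor reads directly.
    mpd: MPDConfigSketch = field(default_factory=MPDConfigSketch)


config = VocoderModelConfigSketch()

# The constructor iterates this list and creates one DiscriminatorP per entry,
# so the number of sub-discriminators equals len(config.mpd.periods).
planned = [f"DiscriminatorP(period={p})" for p in config.mpd.periods]
print(planned)
```

Changing the list of periods in the config is therefore enough to change how many sub-discriminators are built.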
Source code in models/vocoder/univnet/multi_period_discriminator.py
class MultiPeriodDiscriminator(Module):
    r"""MultiPeriodDiscriminator is a class that implements a multi-period discriminator network for the UnivNet vocoder.

    Args:
        model_config (VocoderModelConfig): The configuration object for the UnivNet vocoder model.
    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
    ):
        super().__init__()

        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(period, model_config=model_config)
                for period in model_config.mpd.periods
            ],
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        r"""Forward pass of the multi-period discriminator network.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, channels, time_steps).

        Returns:
            list: A list of output tensors from each discriminator network.
        """
        return [disc(x) for disc in self.discriminators]

forward(x)

Forward pass of the multi-period discriminator network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | The input tensor of shape (batch_size, channels, time_steps). | required |

Returns:

| Type | Description |
| --- | --- |
| list[Tensor] | A list with one output per discriminator, in the order of model_config.mpd.periods. |
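
The fan-out in `forward` can be sketched in isolation as follows. This is a minimal illustration, not the project's implementation: `StandInDiscriminatorP` is a hypothetical placeholder for the real DiscriminatorP (whose architecture is not shown here), `MultiPeriodSketch` mirrors the class above using that placeholder, and the config object and period values are assumptions.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StandInDiscriminatorP(nn.Module):
    """Hypothetical placeholder for DiscriminatorP; not the real architecture."""

    def __init__(self, period: int, model_config) -> None:
        super().__init__()
        self.period = period
        # A single 1x1 convolution, so the module has parameters and keeps the input shape.
        self.proj = nn.Conv1d(1, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class MultiPeriodSketch(nn.Module):
    """Same fan-out pattern as MultiPeriodDiscriminator, built on the stand-in."""

    def __init__(self, model_config) -> None:
        super().__init__()
        self.discriminators = nn.ModuleList(
            [
                StandInDiscriminatorP(period, model_config=model_config)
                for period in model_config.mpd.periods
            ],
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        # Every sub-discriminator receives the same input tensor.
        return [disc(x) for disc in self.discriminators]


# Assumed period values, chosen only for illustration.
config = SimpleNamespace(mpd=SimpleNamespace(periods=[2, 3, 5]))
model = MultiPeriodSketch(config)

x = torch.randn(4, 1, 8192)  # (batch_size, channels, time_steps)
outputs = model(x)
print(len(outputs))  # one entry per configured period
```

Because the sub-discriminators live in an `nn.ModuleList`, their parameters are registered with the parent module and are picked up by `model.parameters()` and `model.to(device)`.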
