Multi Resolution Discriminator

MultiResolutionDiscriminator

Bases: Module

Multi-resolution discriminator for the UnivNet vocoder.

This class implements a multi-resolution discriminator that consists of multiple DiscriminatorR instances, each operating at a different resolution.

Parameters:

Name | Type | Description | Default
model_config | VocoderModelConfig | Model configuration object. | required

Attributes:

Name | Type | Description
resolutions | list | List of resolutions for each DiscriminatorR instance.
discriminators | ModuleList | List of DiscriminatorR instances.

Methods:

Name | Description
forward | Computes the forward pass of the multi-resolution discriminator.

Source code in models/vocoder/univnet/multi_resolution_discriminator.py
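import torch
from torch import nn
from torch.nn import Module

# VocoderModelConfig and DiscriminatorR are project-local names; their import
# paths are not shown on this page.


class MultiResolutionDiscriminator(Module):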
    r"""Multi-resolution discriminator for the UnivNet vocoder.

    This class implements a multi-resolution discriminator that consists of multiple DiscriminatorR instances, each operating at a different resolution.

    Args:
        model_config (VocoderModelConfig): Model configuration object.

    Attributes:
        resolutions (list): List of resolutions for each DiscriminatorR instance.
        discriminators (nn.ModuleList): List of DiscriminatorR instances.

    Methods:
        forward(x): Computes the forward pass of the multi-resolution discriminator.

    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
    ):
        super().__init__()

        self.resolutions = model_config.mrd.resolutions
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorR(resolution, model_config=model_config)
                for resolution in self.resolutions
            ],
        )

    def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]:
        r"""Computes the forward pass of the multi-resolution discriminator.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T].

        Returns:
            list: List of tuples containing the intermediate feature maps and the output scores for each `DiscriminatorR` instance.
        """
        # One (feat, score) tuple per DiscriminatorR, in the order of self.resolutions.
        return [disc(x) for disc in self.discriminators]
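
The fan-out pattern in the listing above can be tried in isolation. The following is a minimal sketch, not the project's implementation: `ToyDiscriminatorR` and `ToyMultiResolutionDiscriminator` are hypothetical stand-ins, the config is a plain `SimpleNamespace`, and it assumes that each resolution is an `(n_fft, hop_length, win_length)` triple, which this page does not confirm. The resolution values are chosen only for illustration.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyDiscriminatorR(nn.Module):
    """Hypothetical stand-in for DiscriminatorR: one conv over an STFT magnitude."""

    def __init__(self, resolution: tuple[int, int, int]):
        super().__init__()
        # Assumption: a resolution is (n_fft, hop_length, win_length).
        self.n_fft, self.hop_length, self.win_length = resolution
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # x: [B, 1, T] -> magnitude spectrogram [B, 1, n_fft // 2 + 1, frames]
        spec = torch.stft(
            x.squeeze(1),
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, device=x.device),
            return_complex=True,
        ).abs().unsqueeze(1)
        feat = self.conv(spec)
        score = feat.flatten(1)  # one flat score vector per batch item
        return feat, score


class ToyMultiResolutionDiscriminator(nn.Module):
    """Same fan-out pattern as the listing above, built from the stand-in."""

    def __init__(self, model_config):
        super().__init__()
        self.resolutions = model_config.mrd.resolutions
        self.discriminators = nn.ModuleList(
            [ToyDiscriminatorR(resolution) for resolution in self.resolutions]
        )

    def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]:
        return [disc(x) for disc in self.discriminators]


# Hypothetical config values, used only to exercise the sketch.
config = SimpleNamespace(
    mrd=SimpleNamespace(
        resolutions=[(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)],
    ),
)
mrd = ToyMultiResolutionDiscriminator(config)
outputs = mrd(torch.randn(2, 1, 8192))  # batch of 2 mono waveforms

# Each sub-discriminator sees the same waveform at a different
# time-frequency trade-off, so the feature map shapes differ.
shapes = [tuple(feat.shape) for feat, _ in outputs]
```

Because the sub-discriminators are held in an `nn.ModuleList`, their parameters are registered with the parent module and are picked up by `mrd.parameters()` when building an optimizer.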

forward(x)

Computes the forward pass of the multi-resolution discriminator.

Parameters:

Name | Type | Description | Default
x | Tensor | Input tensor of shape [B, C, T]. | required

Returns:

Name | Type | Description
list | list[tuple[Tensor, Tensor]] | List of tuples containing the intermediate feature maps and the output scores for each DiscriminatorR instance.

Source code in models/vocoder/univnet/multi_resolution_discriminator.py
def forward(self, x: torch.Tensor) -> list[tuple[torch.Tensor, torch.Tensor]]:
    r"""Computes the forward pass of the multi-resolution discriminator.

    Args:
        x (torch.Tensor): Input tensor of shape [B, C, T].

    Returns:
        list: List of tuples containing the intermediate feature maps and the output scores for each `DiscriminatorR` instance.
    """
    # One (feat, score) tuple per DiscriminatorR, in the order of self.resolutions.
    return [disc(x) for disc in self.discriminators]
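
The list returned by `forward` is usually reduced to a single scalar during training. The sketch below shows one common way to do that, a least-squares GAN discriminator loss averaged over the sub-discriminators. This page does not say which loss the project uses, so treat the function `lsgan_discriminator_loss` and the stand-in tensors as assumptions made for illustration.

```python
import torch


def lsgan_discriminator_loss(
    real_outputs: list[tuple[torch.Tensor, torch.Tensor]],
    fake_outputs: list[tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Average least-squares loss over all (feat, score) pairs.

    Real scores are pushed toward 1 and fake scores toward 0.
    The feature maps are ignored here.
    """
    loss = torch.zeros(())
    for (_, score_real), (_, score_fake) in zip(real_outputs, fake_outputs):
        loss = loss + torch.mean((score_real - 1.0) ** 2) + torch.mean(score_fake**2)
    return loss / len(real_outputs)


# Stand-in outputs shaped like the return value of forward():
# three (feat, score) tuples, one per sub-discriminator.
real_outputs = [(torch.zeros(2, 4), torch.ones(2, 10)) for _ in range(3)]
fake_outputs = [(torch.zeros(2, 4), torch.zeros(2, 10)) for _ in range(3)]

# A discriminator that scores real as 1 and fake as 0 has zero loss.
perfect = lsgan_discriminator_loss(real_outputs, fake_outputs)

# Swapping the inputs gives the worst case for these stand-in scores.
worst = lsgan_discriminator_loss(fake_outputs, real_outputs)
```

Averaging over the list keeps the loss scale independent of how many resolutions are configured.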