Style Embed Attention

StyleEmbedAttention

Bases: Module

This mechanism extracts style features from audio data represented as spectrograms.

Each style token (parameterized by an embedding vector) represents a unique style feature. The model applies the StyleEmbedAttention mechanism to combine these style tokens (style features) in a weighted manner. The output of the attention module is a sum of style tokens, with each token weighted by its relevance to the input.
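The weighted combination described above can be sketched in plain Python. This is a minimal illustration, not the module's code: the helper names `softmax` and `combine_style_tokens`, the three 2-dimensional tokens, and the relevance scores are all made up for the example, and the learned query, key, and value projections are left out.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def combine_style_tokens(scores, tokens):
    """Return (weights, embedding): softmax weights and the weighted token sum."""
    weights = softmax(scores)
    dim = len(tokens[0])
    embedding = [
        sum(w * token[i] for w, token in zip(weights, tokens))
        for i in range(dim)
    ]
    return weights, embedding


# Three hypothetical 2-dimensional style tokens and made-up relevance scores.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [2.0, 0.5, 0.5]

weights, style_embedding = combine_style_tokens(scores, tokens)
```

The weights are non-negative and sum to one, so the resulting style embedding is a convex combination of the tokens, dominated here by the first token because it has the highest score.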

This technique is often used in text-to-speech (TTS) systems such as Tacotron 2, where the goal is to modulate the prosody, stress, and intonation of the synthesized speech based on reference audio or other control parameters. The concept of "global style tokens" (GST) was introduced in Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis by Yuxuan Wang et al.

The StyleEmbedAttention class is a PyTorch module that implements multi-head attention. It operates on a query and a set of key-value pairs to produce an output; keys and values are both projected from the same input tensor (key_soft).
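Written as equations, the computation in the module's forward pass looks roughly as follows. The notation is introduced here only for illustration and does not appear in the original: $X_q$ stands for `query`, $X_k$ for `key_soft`, $h$ for `num_heads`, $d_{\text{key}}$ for `key_dim`, and $Q_i$, $K_i$, $V_i$ for the $i$-th slice of width `num_units / num_heads` along the feature dimension.

```latex
\begin{aligned}
Q &= X_q W^{Q}, \qquad K = X_k W^{K}, \qquad V = X_k W^{V}, \\
\mathrm{head}_i &= \operatorname{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d_{\text{key}}}}\right) V_i, \qquad i = 1, \dots, h, \\
\mathrm{out} &= \operatorname{Concat}\left(\mathrm{head}_1, \dots, \mathrm{head}_h\right).
\end{aligned}
```

Note that the scale factor in the source is the square root of `key_dim`, the dimensionality of the input keys, rather than the per-head width.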

Builds the StyleEmbedAttention network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query_dim | int | Dimensionality of the query vectors. | required |
| key_dim | int | Dimensionality of the key vectors. | required |
| num_units | int | Total dimensionality of the projected query, key, and value vectors, summed over all heads. | required |
| num_heads | int | Number of parallel attention layers (heads). | required |

Note: num_units should be divisible by num_heads.
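To see why the divisibility note matters, the sketch below reproduces the chunk arithmetic of splitting `num_units` features into pieces of width `num_units // num_heads`, which is how the forward pass forms its heads. Both helpers (`head_chunk_sizes` and `validate_heads`) are hypothetical and not part of the module; they assume the usual split behaviour where every chunk has the requested width except possibly a smaller last one.

```python
def head_chunk_sizes(num_units, num_heads):
    """Chunk widths obtained by splitting num_units into pieces of num_units // num_heads."""
    split_size = num_units // num_heads
    if split_size == 0:
        raise ValueError("num_heads must not exceed num_units")
    full, remainder = divmod(num_units, split_size)
    return [split_size] * full + ([remainder] if remainder else [])


def validate_heads(num_units, num_heads):
    """Hypothetical guard: return the per-head width or raise if it is not an integer."""
    if num_units % num_heads != 0:
        raise ValueError(
            f"num_units ({num_units}) must be divisible by num_heads ({num_heads})"
        )
    return num_units // num_heads


even = head_chunk_sizes(256, 8)      # eight equal chunks, one per head
too_many = head_chunk_sizes(10, 4)   # five chunks of width 2, not four heads
ragged = head_chunk_sizes(10, 3)     # unequal chunks, which cannot be stacked
```

With a divisible pair the split yields exactly `num_heads` equal chunks. Otherwise it yields either the wrong number of heads or chunks of unequal width, so a guard like `validate_heads` in the constructor would be one way to fail early.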

Source code in models/tts/delightful_tts/attention/style_embed_attention.py
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module


class StyleEmbedAttention(Module):
    r"""Extracts style features from audio data represented as spectrograms.

    Each style token (parameterized by an embedding vector) represents a unique style feature. The model applies the `StyleEmbedAttention` mechanism to combine these style tokens (style features) in a weighted manner. The output of the attention module is a sum of style tokens, with each token weighted by its relevance to the input.

    This technique is often used in text-to-speech (TTS) systems such as Tacotron 2, where the goal is to modulate the prosody, stress, and intonation of the synthesized speech based on reference audio or other control parameters. The concept of "global style tokens" (GST) was introduced in
    [Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis](https://arxiv.org/abs/1803.09017) by Yuxuan Wang et al.

    The `StyleEmbedAttention` class is a PyTorch module that implements multi-head attention.
    It operates on a query and a set of key-value pairs to produce an output;
    keys and values are both projected from the same input tensor (`key_soft`).

    Builds the `StyleEmbedAttention` network.

    Args:
        query_dim (int): Dimensionality of the query vectors.
        key_dim (int): Dimensionality of the key vectors.
        num_units (int): Total dimensionality of the projected query, key, and value vectors, summed over all heads.
        num_heads (int): Number of parallel attention layers (heads).

    Note: `num_units` should be divisible by `num_heads`.
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        num_units: int,
        num_heads: int,
    ):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False,
        )

    def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor:
        r"""Compute the attention scores over the keys and return the attention-weighted values.

        Args:
            query (torch.Tensor): The input tensor for queries of shape `[N, T_q, query_dim]`
            key_soft (torch.Tensor): The input tensor for keys of shape `[N, T_k, key_dim]`

        Returns:
            out (torch.Tensor): The output tensor of shape `[N, T_q, num_units]`
        """
        # Width of each head's slice of the projected features.
        split_size = self.num_units // self.num_heads

        # Project the inputs; keys and values both come from key_soft.
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key_soft)  # [N, T_k, num_units]
        values = self.W_value(key_soft)  # [N, T_k, num_units]

        # Split the feature dimension into heads and stack them on a new leading axis.
        # [h, N, T_q, num_units/h]
        queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0)
        # [h, N, T_k, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)
        # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)

        # scores = softmax(QK^T / sqrt(key_dim)); the scale uses the input key_dim
        scores_soft = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores_soft = scores_soft / (self.key_dim**0.5)
        scores_soft = F.softmax(scores_soft, dim=3)

        # out = scores * V
        # [h, N, T_q, num_units/h]
        out_soft = torch.matmul(scores_soft, values)

        # Concatenate the heads back along the feature dimension.
        # [N, T_q, num_units]
        return torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(0)
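The shape bookkeeping in the forward pass can be checked with a standalone sketch. It does not import the class: random matrices stand in for the three bias-free linear layers, and every size below (batch of 2, one query step, 10 style tokens, and so on) is a hypothetical value chosen only for illustration. It also compares the module's split-and-concatenate head merge with a permute-and-reshape variant.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
N, T_q, T_k = 2, 1, 10  # batch size, query length, number of style tokens
query_dim, key_dim = 128, 32
num_units, num_heads = 256, 8
split_size = num_units // num_heads

# Random stand-ins for the three bias-free projection matrices.
w_query = torch.randn(query_dim, num_units)
w_key = torch.randn(key_dim, num_units)
w_value = torch.randn(key_dim, num_units)

query = torch.randn(N, T_q, query_dim)
key_soft = torch.randn(N, T_k, key_dim)


def split_heads(x):
    """[N, T, num_units] -> [h, N, T, num_units/h]."""
    return torch.stack(torch.split(x, split_size, dim=2), dim=0)


queries = split_heads(query @ w_query)
keys = split_heads(key_soft @ w_key)
values = split_heads(key_soft @ w_value)

# Scaled dot-product scores, normalised over the key axis.
scores = torch.matmul(queries, keys.transpose(2, 3)) / (key_dim**0.5)
weights = F.softmax(scores, dim=3)  # [h, N, T_q, T_k]
out_heads = torch.matmul(weights, values)  # [h, N, T_q, num_units/h]

# Merge the heads with split + cat + squeeze, as in the source above.
merged = torch.cat(torch.split(out_heads, 1, dim=0), dim=3).squeeze(0)
# Same result via permute + reshape.
merged_alt = out_heads.permute(1, 2, 0, 3).reshape(N, T_q, num_units)
```

Both merges place head 0's features first, then head 1's, and so on, so they produce the same `[N, T_q, num_units]` tensor; the permute-and-reshape form simply avoids building a list of per-head tensors.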

forward(query, key_soft)

Forward pass of the StyleEmbedAttention module: computes the attention scores over the keys and returns the attention-weighted values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| query | Tensor | The input tensor for queries of shape [N, T_q, query_dim] | required |
| key_soft | Tensor | The input tensor for keys of shape [N, T_k, key_dim] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| out | Tensor | The output tensor of shape [N, T_q, num_units] |
