Skip to content

Multi-Head Attention

MultiHeadAttention

Bases: Module

A class that implements a Multi-head Attention mechanism. Multi-head attention allows the model to focus on different positions, capturing various aspects of the input.

Parameters:

Name Type Description Default
query_dim int

The dimensionality of the query.

required
key_dim int

The dimensionality of the key.

required
num_units int

The total number of dimensions of the output.

required
num_heads int

The number of parallel attention layers (multi-heads).

required
query, and key
  • query: Tensor of shape [N, T_q, query_dim]
  • key: Tensor of shape [N, T_k, key_dim]
Outputs
  • An output tensor of shape [N, T_q, num_units]
Source code in models/tts/delightful_tts/attention/multi_head_attention.py
class MultiHeadAttention(Module):
    r"""A class that implements a Multi-head Attention mechanism.
    Multi-head attention allows the model to focus on different positions,
    capturing various aspects of the input.

    Args:
        query_dim (int): The dimensionality of the query.
        key_dim (int): The dimensionality of the key.
        num_units (int): The total number of dimensions of the output.
        num_heads (int): The number of parallel attention layers (multi-heads).

    Inputs: query, and key
        - **query**: Tensor of shape [N, T_q, query_dim]
        - **key**: Tensor of shape [N, T_k, key_dim]

    Outputs:
        - An output tensor of shape [N, T_q, num_units]
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        num_units: int,
        num_heads: int,
    ):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False,
        )

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        r"""Performs the forward pass over input tensors.

        Args:
            query (torch.Tensor): The input tensor containing query vectors.
                It is expected to have the dimensions [N, T_q, query_dim]
                where N is the batch size, T_q is the sequence length of queries,
                and query_dim is the dimensionality of a single query vector.

            key (torch.Tensor): The input tensor containing key vectors.
                It is expected to have the dimensions [N, T_k, key_dim]
                where N is the batch size, T_k is the sequence length of keys,
                and key_dim is the dimensionality of a single key vector.

        Returns:
            torch.Tensor: The output tensor of shape [N, T_q, num_units] which
                represents the results of the multi-head attention mechanism applied
                on the provided queries and keys.
        """
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)
        split_size = self.num_units // self.num_heads

        querys = torch.stack(
            torch.split(querys, split_size, dim=2), dim=0,
        )  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2), dim=0,
        )  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2), dim=0,
        )  # [h, N, T_k, num_units/h]
        # score = softmax(QK^T / (d_k ** 0.5))

        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)
        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
            0,
        )  # [N, T_q, num_units]

forward(query, key)

Performs the forward pass over input tensors.

Parameters:

Name Type Description Default
query Tensor

The input tensor containing query vectors. It is expected to have the dimensions [N, T_q, query_dim] where N is the batch size, T_q is the sequence length of queries, and query_dim is the dimensionality of a single query vector.

required
key Tensor

The input tensor containing key vectors. It is expected to have the dimensions [N, T_k, key_dim] where N is the batch size, T_k is the sequence length of keys, and key_dim is the dimensionality of a single key vector.

required

Returns:

Type Description
Tensor

torch.Tensor: The output tensor of shape [N, T_q, num_units] which represents the results of the multi-head attention mechanism applied on the provided queries and keys.

Source code in models/tts/delightful_tts/attention/multi_head_attention.py
def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    r"""Performs the forward pass over input tensors.

    Args:
        query (torch.Tensor): The input tensor containing query vectors.
            It is expected to have the dimensions [N, T_q, query_dim]
            where N is the batch size, T_q is the sequence length of queries,
            and query_dim is the dimensionality of a single query vector.

        key (torch.Tensor): The input tensor containing key vectors.
            It is expected to have the dimensions [N, T_k, key_dim]
            where N is the batch size, T_k is the sequence length of keys,
            and key_dim is the dimensionality of a single key vector.

    Returns:
        torch.Tensor: The output tensor of shape [N, T_q, num_units] which
            represents the results of the multi-head attention mechanism applied
            on the provided queries and keys.
    """
    querys = self.W_query(query)  # [N, T_q, num_units]
    keys = self.W_key(key)  # [N, T_k, num_units]
    values = self.W_value(key)
    split_size = self.num_units // self.num_heads

    querys = torch.stack(
        torch.split(querys, split_size, dim=2), dim=0,
    )  # [h, N, T_q, num_units/h]
    keys = torch.stack(
        torch.split(keys, split_size, dim=2), dim=0,
    )  # [h, N, T_k, num_units/h]
    values = torch.stack(
        torch.split(values, split_size, dim=2), dim=0,
    )  # [h, N, T_k, num_units/h]
    # score = softmax(QK^T / (d_k ** 0.5))

    scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
    scores = scores / (self.key_dim**0.5)
    scores = F.softmax(scores, dim=3)
    # out = score * V
    out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
    return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
        0,
    )  # [N, T_q, num_units]