Relative Multi-Head Attention
RelativeMultiHeadAttention
Bases: Module
Multi-head attention with relative positional encoding. This concept was proposed in Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
d_model | int | The dimension of the model. | 512 |
num_heads | int | The number of attention heads. | 16 |
Inputs: query, key, value, pos_embedding, mask
- query (batch, time, dim): Tensor containing query vectors
- key (batch, time, dim): Tensor containing key vectors
- value (batch, time, dim): Tensor containing value vectors
- pos_embedding (batch, time, dim): Positional embedding tensor
- mask (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
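The two accepted mask shapes can be illustrated with a small sketch. It assumes the mask is a boolean tensor in which `True` marks positions to be masked; this page does not state the convention. The helper `make_padding_mask` is hypothetical: it builds the (batch, 1, time2) form from sequence lengths, and the last lines show how a per-query (time1, time2) pattern, here a causal one chosen purely for illustration, broadcasts it to the (batch, time1, time2) form.

```python
import torch


def make_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Hypothetical helper: returns a (batch, 1, max_len) bool mask that is
    # True where the key position lies beyond the sequence length (padding).
    positions = torch.arange(max_len, device=lengths.device)  # (max_len,)
    padded = positions[None, :] >= lengths[:, None]            # (batch, max_len)
    return padded.unsqueeze(1)                                 # (batch, 1, time2)


# (batch, 1, time2) form: every query position shares the same key mask.
mask = make_padding_mask(torch.tensor([3, 5]), max_len=5)
# mask[0, 0] -> tensor([False, False, False,  True,  True])

# (batch, time1, time2) form: combine with a per-query pattern
# (a causal pattern is used here for illustration only).
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True above the diagonal
full_mask = mask | causal                                            # broadcasts to (2, 5, 5)
```

The (batch, 1, time2) form is the cheaper choice when only padding needs to be hidden, since broadcasting expands it over query positions without materializing a full matrix.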
Returns: - outputs: Tensor produced by the relative multi-head attention module.
Note: d_model should be divisible by num_heads; in other words, d_model % num_heads should be zero.
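The divisibility requirement exists because the model dimension is split evenly across heads. A minimal sketch of that check, using a hypothetical helper `head_dim` (not part of the documented module) with the documented defaults:

```python
def head_dim(d_model: int = 512, num_heads: int = 16) -> int:
    # Hypothetical helper: size of each head when d_model is split evenly.
    if d_model % num_heads != 0:
        raise ValueError(f"d_model ({d_model}) % num_heads ({num_heads}) should be zero")
    return d_model // num_heads


print(head_dim())  # defaults: 512 // 16 -> 32
```

With d_model = 100 and num_heads = 16 the split would leave a remainder, so the helper raises instead of silently truncating the head size.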
Source code in models/tts/delightful_tts/attention/relative_multi_head_attention.py
forward(query, key, value, pos_embedding, mask)
Applies multi-head attention with relative positional encoding to the inputs. It splits the input queries, keys, and values into individual attention heads, adds the bias terms, computes content and position scores, and sums them into the final score. A softmax is applied over the final score, and the resulting attention weights are used to compute the context (the contextual representation of the input).
The forward pass takes the queries, keys, values, and positional embeddings together with a mask.
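For reference, Transformer-XL splits the attention score into a content term and a position term, each with its own learned bias. Written per head $h$ for query position $i$ and key position $j$, in notation assumed here ($\mathbf{q}$, $\mathbf{k}$ for projected queries and keys, $\mathbf{u}_h$ and $\mathbf{v}_h$ for the two biases, $\mathbf{p}_{h,i-j}$ for the projected positional embedding of relative offset $i-j$, $\mathbf{x}_{h,j}$ for projected values, and $d_k = d_{\text{model}} / \text{num\_heads}$), the steps above read roughly as follows. The exact scale factor and how this module indexes pos_embedding by relative offset are not stated on this page.

```latex
\begin{aligned}
S^{\text{content}}_{h,i,j} &= (\mathbf{q}_{h,i} + \mathbf{u}_h)^{\top}\,\mathbf{k}_{h,j} \\
S^{\text{pos}}_{h,i,j}     &= (\mathbf{q}_{h,i} + \mathbf{v}_h)^{\top}\,\mathbf{p}_{h,i-j} \\
A_{h,i,j} &= \operatorname{softmax}_{j}\!\left(\frac{S^{\text{content}}_{h,i,j} + S^{\text{pos}}_{h,i,j}}{\sqrt{d_k}}\right) \\
\mathbf{c}_{h,i} &= \sum_{j} A_{h,i,j}\,\mathbf{x}_{h,j}
\end{aligned}
```

Masked positions are conventionally excluded by setting their score to $-\infty$ (or a very large negative number) before the softmax, so that $A_{h,i,j} = 0$ for them.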
Parameters:
Name | Type | Description | Default |
---|---|---|---|
query | Tensor | The input tensor containing query vectors. | required |
key | Tensor | The input tensor containing key vectors. | required |
value | Tensor | The input tensor containing value vectors. | required |
pos_embedding | Tensor | The positional embedding tensor. | required |
mask | Tensor | The mask tensor containing indices to be masked. | required |
Returns:
Type | Description |
---|---|
Tensor | Tuple[torch.Tensor, torch.Tensor]: The context and attention tensors. |
Tensor | Tensor produced by the relative multi-head attention module. |
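The forward steps described above can be sketched in PyTorch as follows. This is an illustration, not the module's source code. The class name `RelativeMultiHeadAttentionSketch`, its attribute names, the layer layout and initialization, the 1/sqrt(d_head) scale, the boolean mask convention (`True` = masked), and the (context, attention) return order are all assumptions. pos_embedding is assumed to have the same length as key and to be already aligned with key positions, so the Transformer-XL relative-shift step is omitted.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class RelativeMultiHeadAttentionSketch(nn.Module):
    # Hypothetical sketch of the documented steps; not the module's source.

    def __init__(self, d_model: int = 512, num_heads: int = 16):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model % num_heads should be zero")
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)
        # Per-head biases (u and v in Transformer-XL notation), assumed learned.
        self.u_bias = nn.Parameter(torch.empty(num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.empty(num_heads, self.d_head))
        nn.init.xavier_uniform_(self.u_bias)
        nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = nn.Linear(d_model, d_model)

    def _heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, d_model) -> (batch, num_heads, time, d_head)
        batch, time, _ = x.shape
        return x.view(batch, time, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, query, key, value, pos_embedding, mask):
        q = self._heads(self.query_proj(query))
        k = self._heads(self.key_proj(key))
        v = self._heads(self.value_proj(value))
        p = self._heads(self.pos_proj(pos_embedding))  # assumed same length as key

        # Reshape biases to (1, heads, 1, d_head) so they broadcast over batch and time.
        u = self.u_bias[None, :, None, :]
        w = self.v_bias[None, :, None, :]

        content_score = (q + u) @ k.transpose(-2, -1)  # (batch, heads, time1, time2)
        pos_score = (q + w) @ p.transpose(-2, -1)      # (batch, heads, time1, time2)
        score = (content_score + pos_score) / math.sqrt(self.d_head)

        # mask: (batch, 1, time2) or (batch, time1, time2); add a head axis to broadcast.
        score = score.masked_fill(mask.unsqueeze(1), torch.finfo(score.dtype).min)
        attn = F.softmax(score, dim=-1)

        # Weighted sum of values, then merge heads back to (batch, time1, d_model).
        context = (attn @ v).transpose(1, 2).reshape(query.shape[0], query.shape[1], -1)
        return self.out_proj(context), attn


# Usage: self-attention over a padded batch of two sequences (lengths 10 and 7).
torch.manual_seed(0)
layer = RelativeMultiHeadAttentionSketch(d_model=64, num_heads=4)
x = torch.randn(2, 10, 64)
pos = torch.randn(2, 10, 64)
lengths = torch.tensor([10, 7])
mask = (torch.arange(10)[None, :] >= lengths[:, None]).unsqueeze(1)  # (2, 1, 10)
out, attn = layer(x, x, x, pos, mask)
# out.shape -> torch.Size([2, 10, 64]); attn.shape -> torch.Size([2, 4, 10, 10])
```

Filling masked scores with the dtype's most negative finite value rather than `-inf` keeps the softmax finite even if a row happens to be fully masked.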