Dive into Deep Learning
Table Of Contents
Dive into Deep Learning
Table Of Contents

9.1. Attention Mechanism

In Section 8.14, we encode the source sequence input information in the recurrent unit state and then pass it to the decoder to generate the target sequence. A token in the target sequence may closely relate to some tokens in the source sequence instead of the whole source sequence. For example, when translating “Hello world.” to “Bonjour le monde.”, “Bonjour” maps to “Hello” and “monde” maps to “world”. In the seq2seq model, the decoder may implicitly select the corresponding information from the state passed by the decoder. The attention mechanism, however, makes this selection explicit.

Attention is a generalized pooling method with bias alignment over inputs. The core component in the attention mechanism is the attention layer, or called attention for simplicity. An input of the attention layer is called a query. For a query, the attention layer returns the output based on its memory, which is a set of key-value pairs. To be more specific, assume a query \(\mathbf{q}\in\mathbb R^{d_q}\), and the memory contains \(n\) key-value pairs, \((\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_n, \mathbf{v}_n)\), with \(\mathbf{k}_i\in\mathbb R^{d_k}\), \(\mathbf{v}_i\in\mathbb R^{d_v}\). The attention layer then returns an output \(\mathbf o\in\mathbb R^{d_v}\) with the same shape as a value.


Fig. 9.1.1 The attention layer returns an output based on the input query and its memory.

To compute the output, we first assume there is a score function \(\alpha\) which measures the similarity between the query and a key. Then we compute all \(n\) scores \(a_1, \ldots, a_n\) by

(9.1.1)\[a_i = \alpha(\mathbf q, \mathbf k_i).\]

Next we use softmax to obtain the attention weights

(9.1.2)\[b_1, \ldots, b_n = \textrm{softmax}(a_1, \ldots, a_n).\]

The output is then a weighted sum of the values

(9.1.3)\[\mathbf o = \sum_{i=1}^n b_i \mathbf v_i.\]

Different choices of the score function lead to different attention layers. We will discuss two commonly used attention layers in the rest of this section. Before diving into the implementation, we first introduce a masked version of the softmax operator and explain a specialized dot operator nd.batched_dot.

import math
from mxnet import nd
from mxnet.gluon import nn

The masked softmax takes a 3-dim input and allows us to filter out some elements by specifying valid lengths for the last dimension. (Refer to Section 8.12 for the definition of a valid length.)

# Save to the d2l package.
def masked_softmax(X, valid_length):
    # X: 3-D tensor, valid_length: 1-D or 2-D tensor
    if valid_length is None:
        return X.softmax()
        shape = X.shape
        if valid_length.ndim == 1:
            valid_length = valid_length.repeat(shape[1], axis=0)
            valid_length = valid_length.reshape((-1,))
        # fill masked elements with a large negative, whose exp is 0
        X = nd.SequenceMask(X.reshape((-1, shape[-1])), valid_length, True,
                            axis=1, value=-1e6)
        return X.softmax().reshape(shape)

Construct two examples, where each example is a 2-by-4 matrix, as the input. If we specify the valid length for the first example to be 2, then only the first two columns of this example are used to compute softmax.

masked_softmax(nd.random.uniform(shape=(2,2,4)), nd.array([2,3]))
[[[0.488994   0.511006   0.         0.        ]
  [0.43654838 0.56345165 0.         0.        ]]

 [[0.28817102 0.3519408  0.3598882  0.        ]
  [0.29034293 0.25239873 0.45725834 0.        ]]]
<NDArray 2x2x4 @cpu(0)>

The operator nd.batched_dot takes two inputs \(X\) and \(Y\) with shapes \((b, n, m)\) and \((b, m, k)\), respectively. It computes \(b\) dot products, with Z[i,:,:]=dot(X[i,:,:], Y[i,:,:] for \(i=1,\ldots,n\).

nd.batch_dot(nd.ones((2,1,3)), nd.ones((2,3,2)))
[[[3. 3.]]

 [[3. 3.]]]
<NDArray 2x1x2 @cpu(0)>

9.1.1. Dot Product Attention

The dot product assumes the query has the same dimension as the keys, namely \(\mathbf q, \mathbf k_i \in\mathbb R^d\) for all \(i\). It computes the score by an inner product between the query and a key, often then divided by \(\sqrt{d}\) to make the scores less sensitive to the dimension \(d\). In other words,

(9.1.4)\[\alpha(\mathbf q, \mathbf k) = \langle \mathbf q, \mathbf k \rangle /\sqrt{d}.\]

Assume \(\mathbf Q\in\mathbb R^{m\times d}\) contains \(m\) queries and \(\mathbf K\in\mathbb R^{n\times d}\) has all \(n\) keys. We can compute all \(mn\) scores by

(9.1.5)\[\alpha(\mathbf Q, \mathbf K) = \mathbf Q \mathbf K^T /\sqrt{d}.\]

Now let’s implement this layer that supports a batch of queries and key-value pairs. In addition, it supports randomly dropping some attention weights as a regularization.

# Save to the d2l package.
class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, xx)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # set transpose_b=True to swap the last two dimensions of key
        scores = nd.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return nd.batch_dot(attention_weights, value)

Now we create two batches, and each batch has one query and 10 key-value pairs. We specify through valid_length that for the first batch, we will only pay attention to the first 2 key-value pairs, while for the second batch, we will check the first 6 key-value pairs. Therefore, though both batches have the same query and key-value pairs, we obtain different outputs.

atten = DotProductAttention(dropout=0.5)
keys = nd.ones((2,10,2))
values = nd.arange(40).reshape((1,10,4)).repeat(2,axis=0)
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

9.1.2. Multilayer Perceptron Attention

In multilayer perceptron attention, we first project both query and keys into \(\mathbb R^{h}\).

To be more specific, assume learnable parameters \(\mathbf W_k\in\mathbb R^{h\times d_k}\), \(\mathbf W_q\in\mathbb R^{h\times d_q}\), and \(\mathbf v\in\mathbb R^{p}\). Then the score function is defined by

(9.1.6)\[\alpha(\mathbf k, \mathbf q) = \mathbf v^T \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q).\]

This concatenates the key and value in the feature dimension and feeds them into a single hidden layer perceptron with hidden layer size \(h\) and output layer size \(1\). The hidden layer activation function is tanh and no bias is applied.

# Save to the d2l package.
class MLPAttention(nn.Block):
    def __init__(self, units, dropout, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=True to keep query's and key's 3-D shapes.
        self.W_k = nn.Dense(units, activation='tanh',
                            use_bias=False, flatten=False)
        self.W_q = nn.Dense(units, activation='tanh',
                            use_bias=False, flatten=False)
        self.v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_length):
        query, key = self.W_k(query), self.W_q(key)
        # expand query to (batch_size, #querys, 1, units), and key to
        # (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.
        features = query.expand_dims(axis=2) + key.expand_dims(axis=1)
        scores = self.v(features).squeeze(axis=-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return nd.batch_dot(attention_weights, value)

Despite MLPAttention containing an additional MLP model, given the same inputs with identical keys, we obtain the same output as for DotProductAttention.

atten = MLPAttention(units=8, dropout=0.1)
atten(nd.ones((2,1,2)), keys, values, nd.array([2, 6]))
[[[ 2.        3.        4.        5.      ]]

 [[10.       11.       12.000001 13.      ]]]
<NDArray 2x1x4 @cpu(0)>

9.1.3. Summary

  • An attention layer explicitly selects related information.
  • An attention layer’s memory consists of key-value pairs, so its output is close to the values whose keys are similar to the query.