Dive into Deep Learning

7.9. Long Short-term Memory (LSTM)


This section describes another commonly used gated recurrent neural network: long short-term memory (LSTM) [1]. Its structure is slightly more complicated than that of a gated recurrent unit.

7.9.1. Long Short-term Memory

The LSTM introduces three gates: the input gate, the forget gate, and the output gate. It also introduces memory cells with the same shape as the hidden state (some literature treats memory cells as a special kind of hidden state), which are used to record additional information.

7.9.1.1. Input Gates, Forget Gates, and Output Gates

Like the reset gate and the update gate in the gated recurrent unit, as shown in Fig. 7.9, the inputs of the LSTM gates are the current time step input \(\boldsymbol{X}_t\) and the hidden state of the previous time step \(\boldsymbol{H}_{t-1}\). The outputs are computed by fully connected layers with a sigmoid activation function. As a result, the elements of all three gates have values in the range \([0,1]\).


Fig. 7.9 Calculation of input, forget, and output gates in LSTM.

Here, we assume there are \(h\) hidden units and, for a given time step \(t\), the mini-batch input is \(\boldsymbol{X}_t \in \mathbb{R}^{n \times d}\) (number of examples: \(n\), number of inputs: \(d\)) and the hidden state of the previous time step is \(\boldsymbol{H}_{t-1} \in \mathbb{R}^{n \times h}\). For time step \(t\), the input gate \(\boldsymbol{I}_t \in \mathbb{R}^{n \times h}\), forget gate \(\boldsymbol{F}_t \in \mathbb{R}^{n \times h}\), and output gate \(\boldsymbol{O}_t \in \mathbb{R}^{n \times h}\) are calculated as follows:

\[\begin{split}\begin{aligned} \boldsymbol{I}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xi} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hi} + \boldsymbol{b}_i),\\ \boldsymbol{F}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xf} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hf} + \boldsymbol{b}_f),\\ \boldsymbol{O}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}_{xo} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{ho} + \boldsymbol{b}_o), \end{aligned}\end{split}\]

Here, \(\boldsymbol{W}_{xi}, \boldsymbol{W}_{xf}, \boldsymbol{W}_{xo} \in \mathbb{R}^{d \times h}\) and \(\boldsymbol{W}_{hi}, \boldsymbol{W}_{hf}, \boldsymbol{W}_{ho} \in \mathbb{R}^{h \times h}\) are weight parameters and \(\boldsymbol{b}_i, \boldsymbol{b}_f, \boldsymbol{b}_o \in \mathbb{R}^{1 \times h}\) are bias parameters.
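As a quick sanity check, the hedged sketch below evaluates the three gate equations for made-up toy dimensions (\(n=2\), \(d=5\), \(h=4\)); all names and values here are illustrative and are not part of the implementation later in this section.

from mxnet import nd

n, d, h = 2, 5, 4  # toy mini-batch size, input dimension, and number of hidden units
X_t = nd.random.normal(shape=(n, d))     # current time step input X_t
H_prev = nd.random.normal(shape=(n, h))  # previous hidden state H_{t-1}

def gate(W_x, W_h, b):
    # sigma(X_t W_x + H_{t-1} W_h + b); the sigmoid keeps every element in [0, 1]
    return nd.sigmoid(nd.dot(X_t, W_x) + nd.dot(H_prev, W_h) + b)

W_xi, W_hi, b_i = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)
W_xf, W_hf, b_f = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)
W_xo, W_ho, b_o = nd.random.normal(shape=(d, h)), nd.random.normal(shape=(h, h)), nd.zeros(h)

I_t, F_t, O_t = gate(W_xi, W_hi, b_i), gate(W_xf, W_hf, b_f), gate(W_xo, W_ho, b_o)
print(I_t.shape, F_t.shape, O_t.shape)  # each gate is (n, h) = (2, 4)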

7.9.1.2. Candidate Memory Cells

Next, the LSTM needs to compute the candidate memory cell \(\tilde{\boldsymbol{C}}_t\). Its computation is similar to that of the three gates described above, but it uses a tanh function with a value range of \([-1, 1]\) as its activation function, as shown in Fig. 7.10.


Fig. 7.10 Computation of candidate memory cells in LSTM.

For time step \(t\), the candidate memory cell \(\tilde{\boldsymbol{C}}_t \in \mathbb{R}^{n \times h}\) is calculated by the following formula:

\[\tilde{\boldsymbol{C}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}_{xc} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hc} + \boldsymbol{b}_c),\]

Here, \(\boldsymbol{W}_{xc} \in \mathbb{R}^{d \times h}\) and \(\boldsymbol{W}_{hc} \in \mathbb{R}^{h \times h}\) are weight parameters and \(\boldsymbol{b}_c \in \mathbb{R}^{1 \times h}\) is a bias parameter.

7.9.1.3. Memory Cells

We can control the flow of information in the hidden state using the input, forget, and output gates, whose elements take values in \([0, 1]\). This is generally achieved by elementwise multiplication (symbol \(\odot\)). The computation of the current time step memory cell \(\boldsymbol{C}_t \in \mathbb{R}^{n \times h}\) combines the information of the previous time step memory cell and the current time step candidate memory cell, and controls the flow of information through the forget gate and the input gate:

\[\boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1} + \boldsymbol{I}_t \odot \tilde{\boldsymbol{C}}_t.\]

As shown in Fig. 7.11, the forget gate controls whether the information in the memory cell \(\boldsymbol{C}_{t-1}\) of the previous time step is passed to the current time step, and the input gate controls how the input of the current time step \(\boldsymbol{X}_t\) flows into the memory cell of the current time step through the candidate memory cell \(\tilde{\boldsymbol{C}}_t\). If the forget gate is always approximately 1 and the input gate is always approximately 0, the past memory cells will be saved over time and passed to the current time step. This design can cope with the vanishing gradient problem in recurrent neural networks and better capture dependencies for time series with large time step distances.


Fig. 7.11 Computation of memory cells in LSTM. Here, the multiplication sign indicates multiplication by element.
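To make the elementwise update above concrete, here is a minimal sketch with made-up values (not part of the book's implementation): when the forget gate is close to 1 and the input gate is close to 0, the new memory cell is essentially a copy of the old one.

from mxnet import nd

C_prev = nd.array([[0.5, -1.2, 3.0]])   # previous memory cell C_{t-1} (toy values)
C_tilde = nd.array([[0.9, 0.9, 0.9]])   # candidate memory cell (toy values)
F_t = nd.array([[0.99, 0.99, 0.99]])    # forget gate close to 1
I_t = nd.array([[0.01, 0.01, 0.01]])    # input gate close to 0

C_t = F_t * C_prev + I_t * C_tilde      # elementwise: C_t = F_t * C_{t-1} + I_t * C_tilde
print(C_t)                              # approximately C_prev: past memory is carried to the current step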

7.9.1.4. Hidden States

With memory cells, we can also control the flow of information from memory cells to the hidden state \(\boldsymbol{H}_t \in \mathbb{R}^{n \times h}\) through the output gate:

\[\boldsymbol{H}_t = \boldsymbol{O}_t \odot \text{tanh}(\boldsymbol{C}_t).\]

The tanh function here ensures that the hidden state element values are between -1 and 1. It should be noted that when the output gate is approximately 1, the memory cell information is passed to the hidden state for use by the output layer; when the output gate is approximately 0, the memory cell information is retained only inside the memory cell. Fig. 7.12 shows the computation of the hidden state in LSTM.


Fig. 7.12 Computation of hidden state in LSTM. Here, the multiplication sign indicates multiplication by element.
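Similarly, a toy sketch (illustrative values only) shows how the output gate decides how much of the memory cell reaches the hidden state: a gate near 0 keeps the information inside the cell for later time steps.

from mxnet import nd

C_t = nd.array([[0.5, -1.2, 3.0]])      # current memory cell (toy values)
O_open = nd.array([[1.0, 1.0, 1.0]])    # output gate close to 1
O_closed = nd.array([[0.0, 0.0, 0.0]])  # output gate close to 0

print(O_open * nd.tanh(C_t))    # close to tanh(C_t); every element lies in (-1, 1)
print(O_closed * nd.tanh(C_t))  # close to 0: the cell content is not exposed at this step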

7.9.2. Read the Data Set

Below we begin to implement the LSTM. As with the experiments in the previous sections, we use the time machine data set to train the model to generate text.

In [1]:
import sys
sys.path.insert(0, '..')

import d2l
from mxnet import nd
from mxnet.gluon import rnn

(corpus_indices, char_to_idx, idx_to_char,
 vocab_size) = d2l.load_data_time_machine()

7.9.3. Implementation from Scratch

First, we will discuss how to implement LSTM from scratch.

7.9.3.1. Initialize Model Parameters

The code below initializes the model parameters. The hyper-parameter num_hiddens defines the number of hidden units.

In [2]:
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
ctx = d2l.try_gpu()

def get_params():
    def _one(shape):
        return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)

    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                nd.zeros(num_hiddens, ctx=ctx))

    W_xi, W_hi, b_i = _three()  # Input gate parameters
    W_xf, W_hf, b_f = _three()  # Forget gate parameters
    W_xo, W_ho, b_o = _three()  # Output gate parameters
    W_xc, W_hc, b_c = _three()  # Candidate cell parameters
    # Output layer parameters
    W_hq = _one((num_hiddens, num_outputs))
    b_q = nd.zeros(num_outputs, ctx=ctx)
    # Attach a gradient buffer to each parameter
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params
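As a quick check (not part of the original cell), get_params should return the 14 parameter arrays listed above, each with a gradient buffer attached:

params = get_params()
print(len(params))      # 14: four (W_x*, W_h*, b_*) groups plus the output layer W_hq and b_q
print(params[0].shape)  # W_xi has shape (num_inputs, num_hiddens)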

7.9.4. Define the Model

In the initialization function, the LSTM needs to return an additional memory cell with a value of 0 and a shape of (batch size, number of hidden units), alongside the hidden state.

In [3]:
def init_lstm_state(batch_size, num_hiddens, ctx):
    return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx),
            nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))
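For instance, with a hypothetical mini-batch of 2 examples, both returned arrays have shape (batch size, number of hidden units):

state = init_lstm_state(batch_size=2, num_hiddens=num_hiddens, ctx=ctx)
print(state[0].shape, state[1].shape)  # (2, 256) (2, 256): hidden state H and memory cell C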

Below, we define the model based on the LSTM computation expressions above. It should be noted that only the hidden state is passed into the output layer; the memory cells do not participate in the computation of the output layer.

In [4]:
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)
        F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)
        O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)
        C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * C.tanh()
        Y = nd.dot(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H, C)
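Before training, we can run the lstm function once on a dummy mini-batch to confirm the output shapes. The one-hot encoding below uses nd.one_hot directly and the batch size, number of steps, and token indices are made up; this is only a hedged shape check, not the preprocessing performed inside d2l.train_and_predict_rnn.

from mxnet import nd

batch_size, num_steps = 2, 3                    # toy values for a shape check only
X = nd.array([[0, 1, 2], [3, 4, 5]], ctx=ctx)   # toy token indices of shape (batch_size, num_steps)
inputs = [nd.one_hot(X[:, t], vocab_size) for t in range(num_steps)]  # num_steps arrays of (batch_size, vocab_size)
state = init_lstm_state(batch_size, num_hiddens, ctx)
outputs, (H, C) = lstm(inputs, state, get_params())
print(len(outputs), outputs[0].shape, H.shape, C.shape)
# num_steps outputs of shape (batch_size, vocab_size); H and C are (batch_size, num_hiddens)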

7.9.4.1. Train the Model and Generate Text

As in the previous section, we use only adjacent sampling during model training. After setting the hyper-parameters, we train the model and create a 50-character string of text based on the prefixes “traveller” and “time traveller”.

In [5]:
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['traveller', 'time traveller']

We create a string of text based on the currently trained model every 40 epochs.

In [6]:
d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,
                          vocab_size, ctx, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)
epoch 40, perplexity 7.842470, time 0.84 sec
 - traveller are alle the the the the the the the the the the
 - time traveller are alle the the the the the the the the the the
epoch 80, perplexity 3.947809, time 0.84 sec
 - traveller this that sureat in a mather the fire than this i
 - time traveller this that space the time traveller this is all th
epoch 120, perplexity 1.945280, time 0.84 sec
 - traveller doun an a remuthed.  'not the freer thing. that i
 - time traveller, with a slight accession of chate has only two di
epoch 160, perplexity 1.291660, time 0.84 sec
 - traveller. but some three dimensions, are passing along the
 - time traveller.  'it's as i ho expeestines to the grave. but hav

7.9.5. Concise Implementation

In Gluon, we can directly call the LSTM class in the rnn module.

In [7]:
lstm_layer = rnn.LSTM(num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
d2l.train_and_predict_rnn_gluon(model, num_hiddens, vocab_size, ctx,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes)
epoch 40, perplexity 8.385296, time 0.05 sec
 - traveller and the timention time time time time time time t
 - time traveller the timention time time time time time time time
epoch 80, perplexity 4.518857, time 0.05 sec
 - traveller the pereattently coundimented the psecthoung the
 - time traveller thing the time traveller thing the time traveller
epoch 120, perplexity 2.272366, time 0.05 sec
 - traveller the time traveller could his not they that surmic
 - time traveller courded the littler three dimensions of space, bu
epoch 160, perplexity 1.426391, time 0.05 sec
 - traveller the time traveller held in his hand was a glitter
 - time traveller, 'ithar as in always foun armosing wooned and for
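For reference, Gluon's rnn.LSTM layer consumes inputs in (time step, batch, feature) layout. The standalone shape check below (independent of d2l.RNNModel, reusing the hyper-parameters defined above) shows the shapes of the layer output and of the recurrent state, which consists of a hidden state and a memory cell.

from mxnet import nd
from mxnet.gluon import rnn

layer = rnn.LSTM(num_hiddens)   # a fresh, untrained layer used only for this shape check
layer.initialize(ctx=ctx)
X = nd.random.uniform(shape=(num_steps, batch_size, vocab_size), ctx=ctx)  # (T, N, C) layout
state = layer.begin_state(batch_size=batch_size, ctx=ctx)  # [H, C], each (num_layers, batch_size, num_hiddens)
Y, state_new = layer(X, state)
print(Y.shape, state_new[0].shape, state_new[1].shape)
# (35, 32, 256) (1, 32, 256) (1, 32, 256)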

7.9.6. Summary

  • The hidden layer output of LSTM includes hidden states and memory cells. Only hidden states are passed into the output layer.
  • The input, forget, and output gates in LSTM can control the flow of information.
  • LSTM can cope with the vanishing gradient problem in recurrent neural networks and better capture dependencies for time series with large time step distances.

7.9.7. Exercises

  • Adjust the hyper-parameters and observe and analyze the impact on running time, perplexity, and the generated text.
  • Under the same conditions, compare the running time of an LSTM, a GRU, and a recurrent neural network without gates.
  • Since the candidate memory cells ensure that the value range is between -1 and 1 using the tanh function, why does the hidden state need to use the tanh function again to ensure that the output value range is between -1 and 1?

7.9.8. References

[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
