{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Long Short-Term Memory (LSTM)\n",
"\n",
"The challenge of preserving information over long time spans while skipping irrelevant short-term inputs in latent variable models has existed for a long time. One of the earliest approaches to address it was the LSTM by [Hochreiter and Schmidhuber, 1997](http://papers.nips.cc/paper/1215-lstm-can-solve-hard-long-time-lag-problems.pdf). It shares many of the properties of the Gated Recurrent Unit (GRU) and predates it by almost two decades. Its design is slightly more complex than that of the GRU.\n",
"\n",
"Arguably, its design is inspired by the logic gates of a computer. To control a memory cell we need a number of gates. One gate is needed to read out the entries from the cell (as opposed to reading any other cell). We will refer to this as the *output* gate. A second gate is needed to decide when to read data into the cell. We refer to this as the *input* gate. Lastly, we need a mechanism to reset the contents of the cell, governed by a *forget* gate. The motivation for such a design is the same as for the GRU, namely to decide, via a dedicated mechanism, when to remember and when to ignore inputs into the latent state. Let's see how this works in practice.\n",
"\n",
"## Gated Memory Cells\n",
"\n",
"Three gates are introduced in LSTMs: the input gate, the forget gate, and the output gate. In addition, we introduce memory cells that have the same shape as the hidden state. Strictly speaking, a memory cell is just a fancy version of a hidden state, custom engineered to record additional information.\n",
"\n",
"### Input Gates, Forget Gates and Output Gates\n",
"\n",
"Just like with GRUs, the data feeding into the LSTM gates are the input at the current time step $\\mathbf{X}_t$ and the hidden state of the previous time step $\\mathbf{H}_{t-1}$. These inputs are processed by a fully connected layer with a sigmoid activation function to compute the values of the input, forget, and output gates. As a result, the elements of all three gates lie in the range $(0, 1)$.\n",
"\n",
"![Calculation of input, forget, and output gates in an LSTM. ](../img/lstm_0.svg)\n",
"\n",
"We assume there are $h$ hidden units and that the minibatch is of size $n$. The input is then $\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$ (number of examples: $n$, number of inputs: $d$) and the hidden state of the previous time step is $\\mathbf{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$. Correspondingly, the input gate is $\\mathbf{I}_t \\in \\mathbb{R}^{n \\times h}$, the forget gate is $\\mathbf{F}_t \\in \\mathbb{R}^{n \\times h}$, and the output gate is $\\mathbf{O}_t \\in \\mathbb{R}^{n \\times h}$. They are calculated as follows:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf{I}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xi} + \\mathbf{H}_{t-1} \\mathbf{W}_{hi} + \\mathbf{b}_i),\\\\\n",
"\\mathbf{F}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xf} + \\mathbf{H}_{t-1} \\mathbf{W}_{hf} + \\mathbf{b}_f),\\\\\n",
"\\mathbf{O}_t &= \\sigma(\\mathbf{X}_t \\mathbf{W}_{xo} + \\mathbf{H}_{t-1} \\mathbf{W}_{ho} + \\mathbf{b}_o),\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Here $\\mathbf{W}_{xi}, \\mathbf{W}_{xf}, \\mathbf{W}_{xo} \\in \\mathbb{R}^{d \\times h}$ and $\\mathbf{W}_{hi}, \\mathbf{W}_{hf}, \\mathbf{W}_{ho} \\in \\mathbb{R}^{h \\times h}$ are weight parameters and $\\mathbf{b}_i, \\mathbf{b}_f, \\mathbf{b}_o \\in \\mathbb{R}^{1 \\times h}$ are bias parameters, which are broadcast across the $n$ rows of the minibatch.\n",
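"\n",
"To make the shapes concrete, here is a minimal NumPy sketch of the three gate computations for a single time step. It is an illustration rather than the implementation used later in this section; the sizes `n`, `d`, `h`, the random parameters, and the helper names `sigmoid` and `gate` are assumptions made for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n, d, h = 2, 5, 3  # hypothetical batch size, number of inputs, hidden units\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"def gate(X, H_prev, W_x, W_h, b):\n",
"    # sigma(X W_x + H_prev W_h + b); b has shape (1, h) and is broadcast over rows\n",
"    return sigmoid(X @ W_x + H_prev @ W_h + b)\n",
"\n",
"X_t = rng.normal(size=(n, d))\n",
"H_prev = rng.normal(size=(n, h))\n",
"# One (W_x, W_h, b) triple per gate: input (i), forget (f), output (o)\n",
"params = {g: (rng.normal(scale=0.1, size=(d, h)),\n",
"              rng.normal(scale=0.1, size=(h, h)),\n",
"              np.zeros((1, h))) for g in 'ifo'}\n",
"I_t, F_t, O_t = (gate(X_t, H_prev, *params[g]) for g in 'ifo')\n",
"print(I_t.shape, F_t.shape, O_t.shape)  # (2, 3) (2, 3) (2, 3)\n",
"```\n",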
"\n",
"\n",
"### Candidate Memory Cell\n",
"\n",
"Next we design a memory cell. Since we haven't specified the action of the various gates yet, we first introduce a *candidate* memory cell $\\tilde{\\mathbf{C}}_t \\in \\mathbb{R}^{n \\times h}$. Its computation is similar to that of the three gates described above, but it uses a $\\tanh$ function with a value range of $(-1, 1)$ as the activation function. This leads to the following equation at time step $t$:\n",
"\n",
"$$\\tilde{\\mathbf{C}}_t = \\tanh(\\mathbf{X}_t \\mathbf{W}_{xc} + \\mathbf{H}_{t-1} \\mathbf{W}_{hc} + \\mathbf{b}_c).$$\n",
"\n",
"Here $\\mathbf{W}_{xc} \\in \\mathbb{R}^{d \\times h}$ and $\\mathbf{W}_{hc} \\in \\mathbb{R}^{h \\times h}$ are weights and $\\mathbf{b}_c \\in \\mathbb{R}^{1 \\times h}$ is a bias.\n",
"\n",
"![Computation of candidate memory cells in LSTM. ](../img/lstm_1.svg)\n",
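"\n",
"The only difference from the gates is the activation function. The NumPy sketch below applies both activations to the same made-up pre-activation values: the sigmoid used by the gates maps them into $(0, 1)$, while the $\\tanh$ used by the candidate memory cell maps them into $(-1, 1)$, so the candidate can propose negative as well as positive updates.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical pre-activations X_t W_x + H_{t-1} W_h + b for one example, h = 4\n",
"z = np.array([-3.0, -0.5, 0.0, 2.0])\n",
"\n",
"gate_values = 1 / (1 + np.exp(-z))  # sigmoid, as in the input/forget/output gates\n",
"candidate = np.tanh(z)              # tanh, as in the candidate memory cell\n",
"print(np.round(gate_values, 3))\n",
"print(np.round(candidate, 3))\n",
"```\n",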
"\n",
"\n",
"### Memory Cell\n",
"\n",
"In GRUs we had a single mechanism to govern input and forgetting. Here we have two gates: the input gate $\\mathbf{I}_t$, which governs how much we take new data into account via $\\tilde{\\mathbf{C}}_t$, and the forget gate $\\mathbf{F}_t$, which governs how much of the old memory cell content $\\mathbf{C}_{t-1} \\in \\mathbb{R}^{n \\times h}$ we retain. Using the same element-wise multiplication trick as before, we arrive at the following update equation:\n",
"\n",
"$$\\mathbf{C}_t = \\mathbf{F}_t \\odot \\mathbf{C}_{t-1} + \\mathbf{I}_t \\odot \\tilde{\\mathbf{C}}_t.$$\n",
"\n",
"If the forget gate is always approximately 1 and the input gate is always approximately 0, the past memory cell $\\mathbf{C}_{t-1}$ will be saved over time and passed to the current time step. This design was introduced to alleviate the vanishing gradient problem and to better capture long-range dependencies in sequences. We thus arrive at the following flow diagram.\n",
"\n",
"![Computation of memory cells in an LSTM. Here, the multiplication is carried out element-wise. ](../img/lstm_2.svg)\n",
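"\n",
"A small numerical sketch of the update equation illustrates this. The gate values here are hand-picked constants rather than learned ones, and the helper `run_cell` is hypothetical: with $\\mathbf{F}_t \\approx 1$ and $\\mathbf{I}_t = 0$ the cell content survives 50 steps nearly unchanged, whereas with $\\mathbf{F}_t = 0.5$ it is halved at every step.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def run_cell(C0, F, I, C_tilde, steps):\n",
"    # Repeatedly apply C_t = F * C_{t-1} + I * C_tilde (all element-wise)\n",
"    C = C0\n",
"    for _ in range(steps):\n",
"        C = F * C + I * C_tilde\n",
"    return C\n",
"\n",
"C0 = np.array([1.0, -2.0])      # initial memory content\n",
"C_tilde = np.array([0.7, 0.7])  # a constant candidate, for simplicity\n",
"\n",
"keep = run_cell(C0, F=0.999, I=0.0, C_tilde=C_tilde, steps=50)\n",
"leak = run_cell(C0, F=0.5, I=0.0, C_tilde=C_tilde, steps=50)\n",
"print(keep)  # about 0.951 * C0: the memory is largely preserved\n",
"print(leak)  # 0.5**50 * C0, i.e. practically zero\n",
"```\n",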
"\n",
"\n",
"### Hidden States\n",
"\n",
"Lastly, we need to define how to compute the hidden state $\\mathbf{H}_t \\in \\mathbb{R}^{n \\times h}$. This is where the output gate comes into play. In the LSTM it is simply a gated version of the $\\tanh$ of the memory cell. This ensures that the values of $\\mathbf{H}_t$ always lie in the interval $(-1, 1)$. Whenever the output gate is $1$ we effectively pass all memory information through to the predictor, whereas when the output gate is $0$ we retain all information within the memory cell and perform no further processing. The figure below illustrates the data flow.\n",
"\n",
"$$\\mathbf{H}_t = \\mathbf{O}_t \\odot \\tanh(\\mathbf{C}_t).$$\n",
"\n",
"![Computation of the hidden state. Multiplication is element-wise. ](../img/lstm_3.svg)\n",
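"\n",
"The two extreme cases of the output gate can be checked directly. In the NumPy sketch below the memory cell values are made up, and the helper `hidden_state` is a hypothetical name for this example: an output gate of all zeros gives a zero hidden state while $\\mathbf{C}_t$ itself is left intact, and an output gate of all ones gives exactly $\\tanh(\\mathbf{C}_t)$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"C_t = np.array([[2.5, -0.3, 0.0]])  # hypothetical memory cell, n = 1, h = 3\n",
"\n",
"def hidden_state(O_t, C_t):\n",
"    # H_t = O_t * tanh(C_t), element-wise\n",
"    return O_t * np.tanh(C_t)\n",
"\n",
"H_closed = hidden_state(np.zeros_like(C_t), C_t)  # output gate shut\n",
"H_open = hidden_state(np.ones_like(C_t), C_t)     # output gate fully open\n",
"print(H_closed)\n",
"print(H_open)\n",
"```\n",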
"\n",
"\n",
"\n",
"\n",
"## Implementation from Scratch\n",
"\n",
"Now it's time to implement an LSTM. We begin with a model built from scratch. As with the experiments in the previous sections we first need to load the data. We use *The Time Machine* for this."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
"\n",
"import d2l\n",
"from mxnet import nd, init\n",
"from mxnet.gluon import rnn\n",
"\n",
"corpus_indices, vocab = d2l.load_data_time_machine()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize Model Parameters\n",
"\n",
"Next we need to define and initialize the model parameters. As previously, the hyperparameter `num_hiddens` defines the number of hidden units. We initialize the weights from a Gaussian distribution with mean $0$ and standard deviation $0.01$ (the `scale` argument of `nd.random.normal`), and we set the biases to $0$."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"num_inputs, num_hiddens, num_outputs = len(vocab), 256, len(vocab)\n",
"ctx = d2l.try_gpu()\n",
"\n",
"def get_params():\n",
" def _one(shape):\n",
" return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)\n",
"\n",
" def _three():\n",
" return (_one((num_inputs, num_hiddens)),\n",
" _one((num_hiddens, num_hiddens)),\n",
" nd.zeros(num_hiddens, ctx=ctx))\n",
"\n",
" W_xi, W_hi, b_i = _three() # Input gate parameters\n",
" W_xf, W_hf, b_f = _three() # Forget gate parameters\n",
" W_xo, W_ho, b_o = _three() # Output gate parameters\n",
" W_xc, W_hc, b_c = _three() # Candidate cell parameters\n",
" # Output layer parameters\n",
" W_hq = _one((num_hiddens, num_outputs))\n",
" b_q = nd.zeros(num_outputs, ctx=ctx)\n",
"    # Allocate gradient buffers for all parameters\n",
" params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,\n",
" b_c, W_hq, b_q]\n",
" for param in params:\n",
" param.attach_grad()\n",
" return params"
]
},
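{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough check on model size, the sketch below counts the entries created by `get_params`: four $(\\mathbf{W}_x, \\mathbf{W}_h, \\mathbf{b})$ groups (three gates plus the candidate cell) and the output layer. The helper `lstm_param_count` is hypothetical, and the vocabulary size of 28 is an assumed value for illustration; substitute `len(vocab)` for the actual count.\n",
"\n",
"```python\n",
"def lstm_param_count(num_inputs, num_hiddens, num_outputs):\n",
"    # Each of the 4 groups: W_x (d x h), W_h (h x h), b (h,)\n",
"    per_group = (num_inputs * num_hiddens + num_hiddens * num_hiddens\n",
"                 + num_hiddens)\n",
"    # Output layer: W_hq (h x q), b_q (q,)\n",
"    output_layer = num_hiddens * num_outputs + num_outputs\n",
"    return 4 * per_group + output_layer\n",
"\n",
"vocab_size = 28  # assumed value; use len(vocab) in practice\n",
"print(lstm_param_count(vocab_size, 256, vocab_size))  # 299036\n",
"```\n",
"\n",
"Replacing the factor 4 with 3 gives the corresponding count for a GRU, which has two gates plus a candidate hidden state; this is one starting point for the cost comparison in the exercises."
]
},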
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define the Model\n",
"\n",
"The state initialization function of the LSTM must return an additional memory cell alongside the hidden state. Both have the value $0$ and the shape (batch size, number of hidden units). Hence we get the following state initialization."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [],
"source": [
"def init_lstm_state(batch_size, num_hiddens, ctx):\n",
" return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx),\n",
" nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The actual model is defined just as we discussed above, with three gates and an auxiliary memory cell. Note that only the hidden state is passed on to the output layer; the memory cell does not participate in the output computation directly."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"def lstm(inputs, state, params):\n",
" [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,\n",
" W_hq, b_q] = params\n",
" (H, C) = state\n",
" outputs = []\n",
" for X in inputs:\n",
" I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)\n",
" F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)\n",
" O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)\n",
" C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)\n",
" C = F * C + I * C_tilda\n",
" H = O * C.tanh()\n",
" Y = nd.dot(H, W_hq) + b_q\n",
" outputs.append(Y)\n",
" return outputs, (H, C)"
]
},
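{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers who want to experiment without MXNet, the following sketch mirrors `lstm` in NumPy and runs it on a short random sequence. It assumes that `inputs` is a list of `num_steps` one-hot matrices of shape (batch size, vocabulary size), which is our understanding of how the training helpers feed the model; the sizes, the random parameters, and the names `lstm_np` and `make_params` are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"batch_size, vocab_size, num_hiddens, num_steps = 2, 10, 8, 5  # hypothetical sizes\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"def make_params():\n",
"    def one(shape):\n",
"        return rng.normal(scale=0.01, size=shape)\n",
"    params = []\n",
"    for _ in range(4):  # input, forget, output gates, then candidate cell\n",
"        params += [one((vocab_size, num_hiddens)),\n",
"                   one((num_hiddens, num_hiddens)), np.zeros(num_hiddens)]\n",
"    return params + [one((num_hiddens, vocab_size)), np.zeros(vocab_size)]\n",
"\n",
"def lstm_np(inputs, state, params):\n",
"    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,\n",
"     W_hq, b_q] = params\n",
"    H, C = state\n",
"    outputs = []\n",
"    for X in inputs:  # one (batch_size, vocab_size) matrix per time step\n",
"        I = sigmoid(X @ W_xi + H @ W_hi + b_i)\n",
"        F = sigmoid(X @ W_xf + H @ W_hf + b_f)\n",
"        O = sigmoid(X @ W_xo + H @ W_ho + b_o)\n",
"        C_tilda = np.tanh(X @ W_xc + H @ W_hc + b_c)\n",
"        C = F * C + I * C_tilda\n",
"        H = O * np.tanh(C)\n",
"        outputs.append(H @ W_hq + b_q)\n",
"    return outputs, (H, C)\n",
"\n",
"# Random token ids turned into one-hot matrices, one per time step\n",
"tokens = rng.integers(0, vocab_size, size=(batch_size, num_steps))\n",
"inputs = [np.eye(vocab_size)[tokens[:, t]] for t in range(num_steps)]\n",
"state = (np.zeros((batch_size, num_hiddens)), np.zeros((batch_size, num_hiddens)))\n",
"outputs, (H, C) = lstm_np(inputs, state, make_params())\n",
"print(len(outputs), outputs[0].shape, H.shape, C.shape)  # 5 (2, 10) (2, 8) (2, 8)\n",
"```"
]
},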
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training and Prediction\n",
"\n",
"As in the previous section, we only use adjacent sampling during model training. After setting the hyperparameters, we train the model and generate a 50-character string of text based on the prefixes \"traveller\" and \"time traveller\"."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "9"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 25, perplexity 13.632244, time 37.31 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 50, perplexity 10.292623, time 17.40 sec\n",
" - travellere the the the the the the the the the the the the \n",
" - time travellere the the the the the the the the the the the the \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 75, perplexity 8.054282, time 17.28 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 100, perplexity 6.030462, time 17.25 sec\n",
" - traveller the the the the time travelly the the the the tim\n",
" - time traveller the the the the time travelly the the the the tim\n"
]
}
],
"source": [
"num_epochs, num_steps, batch_size, lr, clipping_theta = 100, 35, 32, 3, 1\n",
"prefixes = ['traveller', 'time traveller']\n",
"\n",
"d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,\n",
" corpus_indices, vocab, ctx, False, num_epochs, \n",
" num_steps, lr, clipping_theta, batch_size, prefixes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Concise Implementation\n",
"\n",
"In Gluon, we can call the `LSTM` class in the `rnn` module directly to instantiate the model."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "10"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 125, perplexity 4.431575, time 4.45 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 250, perplexity 1.074528, time 4.47 sec\n",
" - traveller smiled. 'are you sure we can move freely in space\n",
" - time traveller smiled. 'are you sure we can move freely in space\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 375, perplexity 1.039375, time 4.44 sec\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 500, perplexity 1.039455, time 4.34 sec\n",
" - traveller smiled. 'are you sure we can move freely in space\n",
" - time traveller smiled. 'are you sure we can move freely in space\n"
]
}
],
"source": [
"lstm_layer = rnn.LSTM(num_hiddens)\n",
"model = d2l.RNNModel(lstm_layer, len(vocab))\n",
"d2l.train_and_predict_rnn_gluon(model, num_hiddens, corpus_indices, vocab, \n",
" ctx, num_epochs*5, num_steps, lr, \n",
" clipping_theta, batch_size, prefixes)"
]
},
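{
"cell_type": "markdown",
"metadata": {},
"source": [
"Library implementations typically do not evaluate the four affine maps one by one. A common trick is to stack the weights into one $d \\times 4h$ matrix and one $h \\times 4h$ matrix, perform a single pair of matrix products, and split the result into four blocks. The NumPy sketch below checks that this fused form matches the separate equations of this section; the block order `i, f, o, c` used here is an arbitrary choice and is not meant to describe how Gluon lays out its parameters internally.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(1)\n",
"n, d, h = 3, 6, 4  # hypothetical sizes\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"X, H_prev = rng.normal(size=(n, d)), rng.normal(size=(n, h))\n",
"# Separate parameters for the input, forget, output gates and candidate cell\n",
"W_x = [rng.normal(size=(d, h)) for _ in range(4)]\n",
"W_h = [rng.normal(size=(h, h)) for _ in range(4)]\n",
"b = [rng.normal(size=h) for _ in range(4)]\n",
"\n",
"# Separate evaluation, one affine map per block\n",
"separate = [X @ W_x[k] + H_prev @ W_h[k] + b[k] for k in range(4)]\n",
"\n",
"# Fused evaluation: stacked weights, one product each, then split into 4 blocks\n",
"Z = X @ np.hstack(W_x) + H_prev @ np.hstack(W_h) + np.concatenate(b)\n",
"fused = np.split(Z, 4, axis=1)\n",
"\n",
"I, F, O = (sigmoid(z) for z in fused[:3])\n",
"C_tilde = np.tanh(fused[3])\n",
"print(all(np.allclose(s, f) for s, f in zip(separate, fused)))  # True\n",
"```"
]
},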
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"* LSTMs have three types of gates, namely input, forget, and output gates, which control the flow of information.\n",
"* The hidden layer output of LSTM includes hidden states and memory cells. Only hidden states are passed into the output layer. Memory cells are entirely internal.\n",
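"* LSTMs can help cope with vanishing gradients caused by long-range dependencies and with short-range irrelevant data. Exploding gradients are still commonly handled separately, e.g. by gradient clipping as in the training code above.\n",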
"* In many cases LSTMs perform slightly better than GRUs but they are more costly to train and execute due to the larger latent state size. \n",
"* LSTMs are the prototypical latent variable autoregressive model with nontrivial state control. Many variants thereof have been proposed over the years, e.g. multiple layers, residual connections, different types of regularization.\n",
"* Training LSTMs and other sequence models is quite costly because the computation must proceed step by step along long sequences. Later we will encounter alternative models, such as transformers, that can be used in some cases.\n",
"\n",
"## Exercises\n",
"\n",
"1. Adjust the hyperparameters. Observe and analyze the impact on runtime, perplexity, and the generated output.\n",
"1. How would you need to change the model to generate proper words as opposed to sequences of characters?\n",
"1. Compare the computational cost of GRUs, LSTMs, and regular RNNs for a given hidden dimension. Pay special attention to training and inference cost.\n",
"1. Since the candidate memory cells ensure that the value range is between -1 and 1 using the tanh function, why does the hidden state need to use the tanh function again to ensure that the output value range is between -1 and 1?\n",
"1. Implement an LSTM for time series prediction rather than character sequences. \n",
"\n",
"\n",
"## References\n",
"\n",
"[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. *Neural Computation*, 9(8), 1735-1780.\n",
"\n",
"## Scan the QR Code to [Discuss](https://discuss.mxnet.io/t/2368)\n",
"\n",
"![](../img/qr_lstm.svg)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}