{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Weight Decay\n",
"\n",
"Now that we have characterized the problem of overfitting \n",
"and motivated the need for capacity control,\n",
"we can begin discussing some of the popular techniques\n",
"used to these ends in practice.\n",
"Recall that we can always mitigate overfitting \n",
"by going out and collecting more training data,\n",
"that can be costly and time consuming,\n",
"typically making it impossible in the short run.\n",
"For now, let's assume that we have already obtained \n",
"as much high-quality data as our resources permit\n",
"and focus on techniques aimed at limiting the capacity \n",
"of the function classes under consideration. \n",
"\n",
"In our toy example, \n",
"we saw that we could control the complexity of a polynomial \n",
"by adjusting its degree. \n",
"However, most of machine learning \n",
"does not consist of polynomial curve fitting.\n",
"And moreover, even when we focus on polynomial regression,\n",
"when we deal with high-dimensional data,\n",
"manipulating model capacity by tweaking the degree $d$ is problematic.\n",
"To see why, note that for multivariate data\n",
"we must generalize the concept of polynomials \n",
"to include *monomials*, which are simply\n",
"products of powers of variables.\n",
"For example, $x_1^2 x_2$, and $x_3 x_5^2$ are both monomials of degree $3$.\n",
"The number of such terms with a given degree $d$\n",
"blows up as a function of the degree $d$.\n",
"\n",
"Concretely, for vectors of dimensionality $D$,\n",
"the number of monomials of a given degree $d$ is ${D -1 + d} \\choose {D-1}$.\n",
"Hence, a small change in degree, even from say $1$ to $2$ or $2$ to $3$ \n",
"would entail a massive blowup in the complexity of our model.\n",
"Thus, tweaking the degree is too blunt a hammer.\n",
"Instead, we need a more fine-grained tool \n",
"for adjusting function complexity.\n",
"\n",
"## Squared Norm Regularization\n",
"\n",
"*Weight decay* (commonly called *L2* regularization), \n",
"might be the most widely-used technique \n",
"for regularizing parametric machine learning models.\n",
"The basic intuition behind weight decay is \n",
"the notion that among all functions $f$, \n",
"the function $f = 0$ is the simplest. \n",
"Intuitively, we can then measure functions by their proximity to zero. \n",
"But how precisely should we measure\n",
"the distance between a function and zero? \n",
"There is no single right answer.\n",
"In fact, entire branches of mathematics, \n",
"e.g. in functional analysis and the theory of Banach spaces\n",
"are devoted to answering this issue.\n",
"\n",
"For our present purposes, a very simple interpretation will suffice:\n",
"We will consider a linear function \n",
"$f(\\mathbf{x}) = \\mathbf{w}^\\top \\mathbf{x}$ \n",
"to be simple if its weight vector is small. \n",
"We can measure this via $||mathbf{w}||^2$. \n",
"One way of keeping the weight vector small \n",
"is to add its norm as a penalty term \n",
"to the problem of minimizing the loss. \n",
"Thus we replace our original objective, \n",
"*minimize the prediction error on the training labels*,\n",
"with new objective,\n",
"*minimize the sum of the prediction error and the penalty term*.\n",
"Now, if the weight vector becomes too large, \n",
"our learning algorithm will find more profit in\n",
"minimizing the norm $|| \\mathbf{w} ||^2$ \n",
"versus minimizing the training error. \n",
"That's exactly what we want. \n",
"To illustrate things in code, let's revive our previous example\n",
"from our chapter on [Linear Regression](linear-regression.md). \n",
"There, our loss was given by\n",
"\n",
"$$l(\\mathbf{w}, b) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2}\\left(\\mathbf{w}^\\top \\mathbf{x}^{(i)} + b - y^{(i)}\\right)^2.$$\n",
"\n",
"Recall that $\\mathbf{x}^{(i)}$ are the observations, \n",
"$y^{(i)}$ are labels, and $(\\mathbf{w}, b)$ \n",
"are the weight and bias parameters respectively. \n",
"To arrive at a new loss function \n",
"that penalizes the size of the weight vector, \n",
"we need to add $||mathbf{w}||^2$, but how much should we add? \n",
"To address this, we need to add a new hyperparameter,\n",
"that we will call the *regularization constant* and denote by $\\lambda$:\n",
"\n",
"$$l(\\mathbf{w}, b) + \\frac{\\lambda}{2} \\|\\boldsymbol{w}\\|^2$$\n",
"\n",
"This non-negatice parameter $\\lambda \\geq 0$ \n",
"governs the amount of regularization. \n",
"For $\\lambda = 0$, we recover our original loss function, \n",
"whereas for $\\lambda > 0$ we ensure that $\\mathbf{w}$ cannot grow too large. The astute reader might wonder why we are squaring \n",
"the norm of the weight vector. \n",
"We do this for two reasons.\n",
"First, we do it for computational convenience.\n",
"By squaring the L2 norm, we remove the square root,\n",
"leaving the sum of squares of each component of the weight vector.\n",
"This is convenient because it is easy to compute derivatives of a sum of terms (the sum of derivatives equals the derivative of the sum). \n",
"\n",
"Moreover, you might ask, why the L2 norm in the first place and not the L1 norm, or some other distance function.\n",
"In fact, several other choices are valid \n",
"and are popular throughout statistics.\n",
"While L2-regularized linear models constitute \n",
"the classic *ridge regression* algorithm\n",
"L1-regularizaed linear regression \n",
"is a similarly fundamental model in statistics \n",
"popularly known as *lasso regression*.\n",
"\n",
"One mathematical reason for working with the L2 norm and not some other norm,\n",
"is that it penalizes large components of the weight vector\n",
"much more than it penalizes small ones. \n",
"This encourages our learning algorithm to discover models \n",
"which distribute their weight across a larger number of features,\n",
"which might make them more robust in practice \n",
"since they do not depend precariously on a single feature.\n",
"The stochastic gradient descent updates for L2-regularied regression\n",
"are as follows:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"w & \\leftarrow \\left(1- \\frac{\\eta\\lambda}{|\\mathcal{B}|} \\right) \\mathbf{w} - \\frac{\\eta}{|\\mathcal{B}|} \\sum_{i \\in \\mathcal{B}} \\mathbf{x}^{(i)} \\left(\\mathbf{w}^\\top \\mathbf{x}^{(i)} + b - y^{(i)}\\right),\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"As before, we update $\\mathbf{w}$ based on the amount \n",
"by which our estimate differs from the observation. \n",
"However, we also shrink the size of $\\mathbf{w}$ towards $0$.\n",
"That's why the method is sometimes called \"weight decay\":\n",
"because the penalty term literally causes our optimization algorthm \n",
"to *decay* the magnitude of the weight at each step of training. \n",
"This is more convenient than having to pick \n",
"the number of parameters as we did for polynomials. \n",
"In particular, we now have a continuous mechanism \n",
"for adjusting the complexity of $f$. \n",
"Small values of $\\lambda$ correspond to unconstrained $\\mathbf{w}$,\n",
"whereas large values of $\\lambda$ constrain $\\mathbf{w}$ considerably. \n",
"Since we don't want to have large bias terms either, \n",
"we often add $b^2$ as a penalty, too.\n",
"\n",
"## High-dimensional Linear Regression\n",
"\n",
"For high-dimensional regression it is difficult \n",
"to pick the 'right' dimensions to omit. \n",
"Weight-decay regularization is a much more convenient alternative. \n",
"We will illustrate this below.\n",
"First, we will generate some synthetic data as before\n",
"\n",
"$$y = 0.05 + \\sum_{i = 1}^d 0.01 x_i + \\epsilon \\text{ where }\n",
"\\epsilon \\sim \\mathcal{N}(0, 0.01)$$\n",
"\n",
"representing our label as a linear function of our inputs,\n",
"corrupted by Gaussian noise with zero mean and variance 0.01. \n",
"To observe the effects of overfitting more easily,\n",
"we can make our problem high-dimensional,\n",
"setting the data dimension to $d = 200$ \n",
"and working with a relatively small number of training examplesâ€”here we'll set the sample size to 20:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
"\n",
"%matplotlib inline\n",
"import d2l\n",
"from mxnet import autograd, gluon, init, nd\n",
"from mxnet.gluon import data as gdata, loss as gloss, nn\n",
"\n",
"n_train, n_test, num_inputs = 20, 100, 200\n",
"true_w, true_b = nd.ones((num_inputs, 1)) * 0.01, 0.05\n",
"\n",
"features = nd.random.normal(shape=(n_train + n_test, num_inputs))\n",
"labels = nd.dot(features, true_w) + true_b\n",
"labels += nd.random.normal(scale=0.01, shape=labels.shape)\n",
"train_features, test_features = features[:n_train, :], features[n_train:, :]\n",
"train_labels, test_labels = labels[:n_train], labels[n_train:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implementation from Scratch\n",
"\n",
"Next, we will show how to implement weight decay from scratch. \n",
"All we have to do here is to add the squared $\\ell_2$ penalty \n",
"as an additional loss term added to the original target function. \n",
"The squared norm penalty derives its name from the fact \n",
"that we are adding the second power $\\sum_i w_i^2$. \n",
"The $\\ell_2$ is just one among an infinite class of norms call p-norms,\n",
"many of which you might encounter in the future.\n",
"In general, for some number $p$, the $\\ell_p$ norm is defined as\n",
"\n",
"$$\\|\\mathbf{w}\\|_p^p := \\sum_{i=1}^d |w_i|^p$$\n",
"\n",
"### Initialize Model Parameters\n",
"\n",
"First, we'll define a function to randomly initialize our model parameters and run `attach_grad` on each to allocate memory for the gradients we will calculate."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
},
"outputs": [],
"source": [
"def init_params():\n",
" w = nd.random.normal(scale=1, shape=(num_inputs, 1))\n",
" b = nd.zeros(shape=(1,))\n",
" w.attach_grad()\n",
" b.attach_grad()\n",
" return [w, b]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define $\\ell_2$ Norm Penalty\n",
"\n",
"Perhaps the most convenient way to implement this penalty \n",
"is to square all terms in place and summ them up. \n",
"We divide by $2$ by convention\n",
"(when we take the derivative of a quadratic function,\n",
"the $2$ and $1/2$ cancel out, ensuring that the expression \n",
"for the update looks nice and simple)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [],
"source": [
"def l2_penalty(w):\n",
" return (w**2).sum() / 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Training and Testing\n",
"\n",
"The following code defines how to train and test the model \n",
"separately on the training data set and the test data set. \n",
"Unlike the previous sections, here, the $\\ell_2$ norm penalty term \n",
"is added when calculating the final loss function. \n",
"The linear network and the squared loss \n",
"haven't changed since the previous chapter, \n",
"so we'll just import them via `d2l.linreg` and `d2l.squared_loss` \n",
"to reduce clutter."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "7"
}
},
"outputs": [],
"source": [
"batch_size, num_epochs, lr = 1, 100, 0.003\n",
"net, loss = d2l.linreg, d2l.squared_loss\n",
"train_iter = gdata.DataLoader(gdata.ArrayDataset(\n",
" train_features, train_labels), batch_size, shuffle=True)\n",
"\n",
"def fit_and_plot(lambd):\n",
" w, b = init_params()\n",
" train_ls, test_ls = [], []\n",
" for _ in range(num_epochs):\n",
" for X, y in train_iter:\n",
" with autograd.record():\n",
" # The L2 norm penalty term has been added\n",
" l = loss(net(X, w, b), y) + lambd * l2_penalty(w)\n",
" l.backward()\n",
" d2l.sgd([w, b], lr, batch_size)\n",
" train_ls.append(loss(net(train_features, w, b),\n",
" train_labels).mean().asscalar())\n",
" test_ls.append(loss(net(test_features, w, b),\n",
" test_labels).mean().asscalar())\n",
" d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',\n",
" range(1, num_epochs + 1), test_ls, ['train', 'test'])\n",
" print('l2 norm of w:', w.norm().asscalar())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training without Regularization\n",
"\n",
"Next, let's train and test the high-dimensional linear regression model. \n",
"When `lambd = 0` we do not use weight decay. \n",
"As a result, while the training error decreases, the test error does not. \n",
"This is a perfect example of overfitting."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "8"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"