{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Selection, Underfitting and Overfitting\n",
"\n",
"As machine learning scientists, our goal is to discover general patterns. \n",
"Say, for example, that we wish to learn the pattern\n",
"that associates genetic markers with the development of dementia in adulthood. \n",
"It's easy enough to memorize our training set. \n",
"Each person's genes uniquely identify them, \n",
"not just among people represented in our dataset,\n",
"but among all people on earth!\n",
"\n",
"Given the genetic markers representing some person,\n",
"we don't want our model to simply recognize \"oh, that's Bob\",\n",
"and then output the classification, \n",
"say among {*dementia*, *mild cognitive impairment*, *healthy*}, \n",
"that corresponds to Bob.\n",
"Rather, our goal is to discover patterns \n",
"that capture regularities in the underlying population\n",
"from which our training set was drawn.\n",
"If we are successful in this endeavour,\n",
"then we could successfully assess risk\n",
"even for individuals that we have never encountered before.\n",
"This problem (how to discover patterns that *generalize*) is \n",
"the fundamental problem of machine learning.\n",
"\n",
"\n",
"The danger is that when we train models, \n",
"we access just a small sample of data. \n",
"The largest public image datasets contain roughly one million images.\n",
"And more often we have to learn from thousands or tens of thousands.\n",
"In a large hospital system we might access \n",
"hundreds of thousands of medical records. \n",
"With finite samples, we always run the risk\n",
"that we might discover *apparent* associations \n",
"that turn out not to hold up when we collect more data.\n",
"\n",
"Let’s consider an extreme pathological case. \n",
"Imagine that you want to learn to predict \n",
"which people will repay their loans.\n",
"A lender hires you as a data scientist to investigate,\n",
"handing over the complete files on 100 applicants, \n",
"5 of whom defaulted on their loans within 3 years. \n",
"Realistically, the files might include hundreds of potential features, including income, occupation, credit score, length of employment etc. \n",
"Moreover, say that they additionally hand over video footage \n",
"of each applicant's interview with their lending agent. \n",
"\n",
"Now suppose that after featurizing the data into an enormous design matrix, \n",
"you discover that of the 5 applicants who defaulted, \n",
"all of them were wearing blue shirts during their interviews, \n",
"while only 40% of the general population wore blue shirts. \n",
"There's a good chance that if you train a predictive model\n",
"to predict default, it might rely upon blue-shirt-wearing\n",
"as an important feature.\n",
"\n",
"Even if in fact defaulters were no more likely to wear blue shirts\n",
"than people in the general population,\n",
"there’s a $0.4^5 \\approx 0.01$ probability that \n",
"we would observe all five defaulters wearing blue shirts.\n",
"With just $5$ positive examples of defaults \n",
"and hundreds or thousands of features, \n",
"we would probably find a large number of features \n",
"that appear to be perfectly predictive of our label\n",
"just due to random chance. \n",
"With an unlimited amount of data, we would expect \n",
"these *spurious* associations to eventually disappear.\n",
"But we seldom have that luxury.\n",
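"\n",
"To see how easily such coincidences arise, the sketch below simulates the loan scenario with purely random binary features. It uses NumPy rather than MXNet, and the numbers (100 applicants, 5 defaulters, 1000 features, each present with probability 0.4) are illustrative assumptions, not real lending data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Illustrative setup: 100 applicants, the first 5 of whom defaulted,\n",
"# and 1000 binary features that are pure noise (each present w.p. 0.4).\n",
"rng = np.random.default_rng(0)\n",
"n_applicants, n_defaulters, n_features = 100, 5, 1000\n",
"X = rng.random((n_applicants, n_features)) < 0.4\n",
"is_default = np.zeros(n_applicants, dtype=bool)\n",
"is_default[:n_defaulters] = True\n",
"\n",
"# Count the noise features that every defaulter happens to share,\n",
"# like the blue shirts in the example above.\n",
"shared_by_all_defaulters = X[is_default].all(axis=0)\n",
"num_spurious = int(shared_by_all_defaulters.sum())\n",
"\n",
"print('P(one feature is shared by all 5 defaulters):', 0.4 ** 5)\n",
"print('expected number of such features:', 0.4 ** 5 * n_features)\n",
"print('observed number of such features:', num_spurious)\n",
"```\n",
"\n",
"On average, roughly ten of these meaningless features would look like the blue shirts. Each one is a candidate spurious association that a flexible model could latch onto.\n",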
"\n",
"The phenomenon of fitting our training data \n",
"more closely than we fit the underlying distribution is called overfitting, and the techniques used to combat overfitting are called regularization.\n",
"In the previous sections, you might have observed \n",
"this effect while experimenting with the Fashion-MNIST dataset. \n",
"If you altered the model structure or the hyper-parameters during the experiment, you might have noticed that with enough nodes, layers, and training epochs, the model can eventually reach perfect accuracy on the training set, even as the accuracy on test data deteriorates. \n",
"\n",
"\n",
"## Training Error and Generalization Error\n",
"\n",
"In order to discuss this phenomenon more formally,\n",
"we need to differentiate between *training error* and *generalization error*.\n",
"The training error is the error of our model \n",
"as calculated on the training data set, \n",
"while generalization error is the expectation of our model's error\n",
"were we to apply it to an infinite stream of additional data points \n",
"drawn from the same underlying data distribution as our original sample. \n",
"\n",
"Problematically, *we can never calculate the generalization error exactly*.\n",
"That's because the infinite stream of data is an imaginary object.\n",
"In practice, we must *estimate* the generalization error\n",
"by applying our model to an independent test set\n",
"consisting of a random selection of data points \n",
"that were withheld from our training set. \n",
"\n",
"The following three thought experiments \n",
"will help illustrate this situation better. \n",
"Consider a college student trying to prepare for her final exam. \n",
"A diligent student will strive to practice well \n",
"and test her abilities using exams from previous years. \n",
"Nonetheless, doing well on past exams is no guarantee \n",
"that she will excel when it matters. \n",
"For instance, the student might try to prepare \n",
"by rote learning the answers to the exam questions. \n",
"This requires the student to memorize many things. \n",
"She might even remember the answers for past exams perfectly. \n",
"Another student might prepare by trying to understand \n",
"the reasons for giving certain answers. \n",
"In most cases, the latter student will do much better.\n",
"\n",
"Likewise, consider a model that simply uses a lookup table to answer questions. If the set of allowable inputs is discrete and reasonably small, then perhaps after viewing *many* training examples, this approach would perform well. Still, this model has no ability to do better than random guessing when faced with examples that it has never seen before. \n",
"In reality the input spaces are far too large to memorize the answers corresponding to every conceivable input. For example, consider black and white $28\\times28$ images. If each pixel can take one among $256$ grayscale values, then there are $256^{784}$ possible images. That means that there are far more low-res grayscale thumbnail-sized images than there are atoms in the universe. Even if we could encounter all of this data, we could never afford to store the lookup table.\n",
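"\n",
"Here is a quick check of the arithmetic as a plain-Python sketch. The figure of roughly $10^{80}$ atoms in the observable universe is a commonly quoted order-of-magnitude estimate, not a number from this section.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Distinct 28x28 images when each pixel takes one of 256 gray levels.\n",
"num_images = 256 ** (28 * 28)  # exact Python integer\n",
"num_digits = len(str(num_images))\n",
"\n",
"print('decimal digits in 256**784:', num_digits)\n",
"print('log10 of the count:', 784 * math.log10(256))\n",
"print('exceeds 10**80:', num_images > 10 ** 80)\n",
"```\n",
"\n",
"The count has 1889 decimal digits, so it exceeds $10^{80}$ by well over a thousand orders of magnitude.\n",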
"\n",
"Lastly, consider the problem of trying \n",
"to classify the outcomes of coin tosses (class 0: heads, class 1: tails)\n",
"based on some contextual features that might be available. \n",
"No matter what algorithm we come up with,\n",
"the generalization error will always be $\\frac{1}{2}$,\n",
"since the outcome of a fair coin toss is independent of any features.\n",
"However, for most algorithms, \n",
"we should expect our training error to be considerably lower, \n",
"depending on the luck of the draw,\n",
"even if we didn't have any features!\n",
"Consider the dataset {0, 1, 1, 1, 0, 1}.\n",
"Our feature-less model would have to fall back on always predicting \n",
"the *majority class*, which appears from our limited sample to be *1*.\n",
"In this case, the model that always predicts class 1 \n",
"will incur a training error of $\\frac{1}{3}$,\n",
"considerably better than our generalization error. \n",
"As we increase the amount of data, \n",
"the probability that the fraction of heads \n",
"will deviate significantly from $\\frac{1}{2}$ diminishes,\n",
"and our training error would come to match the generalization error.\n",
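"\n",
"The majority-class argument can be checked directly. The NumPy sketch below (the function name `majority_class_train_error` is ours, for illustration) reproduces the $\\frac{1}{3}$ training error on the six-toss dataset. It then shows the training error drifting toward $\\frac{1}{2}$ as the number of simulated fair-coin tosses grows.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def majority_class_train_error(labels):\n",
"    # A feature-less classifier predicts the most frequent label in the\n",
"    # sample and is scored on that same sample.\n",
"    labels = np.asarray(labels)\n",
"    majority = int(labels.mean() >= 0.5)  # ties are broken toward class 1\n",
"    return majority, float(np.mean(labels != majority))\n",
"\n",
"print(majority_class_train_error([0, 1, 1, 1, 0, 1]))\n",
"\n",
"# With simulated fair-coin labels, the training error approaches 1/2.\n",
"rng = np.random.default_rng(0)\n",
"for n in [10, 100, 10000]:\n",
"    _, err = majority_class_train_error(rng.integers(0, 2, size=n))\n",
"    print(n, err)\n",
"```\n",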
"\n",
"\n",
"### Statistical Learning Theory\n",
"\n",
"Since generalization is the fundamental problem in machine learning,\n",
"you might not be surprised to learn \n",
"that many mathematicians and theorists have dedicated their lives \n",
"to developing formal theories to describe this phenomenon. \n",
"In their [eponymous theorem](https://en.wikipedia.org/wiki/Glivenko%E2%80%93Cantelli_theorem), Glivenko and Cantelli \n",
"proved that the empirical distribution of a sample \n",
"converges uniformly to the underlying distribution as the sample size grows. \n",
"In a series of seminal papers, [Vapnik and Chervonenkis](https://en.wikipedia.org/wiki/Vapnik%E2%80%93Chervonenkis_theory) \n",
"extended this theory to more general classes of functions. \n",
"This work laid the foundations of [Statistical Learning Theory](https://en.wikipedia.org/wiki/Statistical_learning_theory).\n",
"\n",
"\n",
"In the **standard supervised learning setting**, which we have addressed up until now and will stick with throughout most of this book,\n",
"we assume that both the training data and the test data \n",
"are drawn *independently* from *identical* distributions \n",
"(commonly called the i.i.d. assumption).\n",
"This means that whatever process samples our data has no *memory*. \n",
"The 2nd example drawn and the 3rd drawn \n",
"are no more correlated than the 2nd and the 2-millionth sample drawn.\n",
"\n",
"Being a good machine learning scientist requires thinking critically,\n",
"and already you should be poking holes in this assumption,\n",
"coming up with common cases where the assumption fails. \n",
"What if we train a mortality risk predictor \n",
"on data collected from patients at UCSF,\n",
"and apply it on patients at Massachusetts General Hospital? \n",
"These distributions are simply not identical.\n",
"Moreover, draws might be correlated in time. \n",
"What if we are classifying the topics of Tweets?\n",
"The news cycle would create temporal dependencies \n",
"in the topics being discussed, violating any assumptions of independence.\n",
"\n",
"Sometimes we can get away with minor violations of the i.i.d. assumption \n",
"and our models will continue to work remarkably well.\n",
"After all, nearly every real-world application \n",
"involves at least some minor violation of the i.i.d. assumption,\n",
"and yet we have useful tools for face recognition,\n",
"speech recognition, language translation, etc.\n",
"\n",
"Other violations are sure to cause trouble.\n",
"Imagine, for example, if we tried to train \n",
"a face recognition system by training it \n",
"exclusively on university students \n",
"and then wanted to deploy it as a tool\n",
"for monitoring geriatric patients in a nursing home population. \n",
"This is unlikely to work well since college students \n",
"tend to look considerably different from the elderly. \n",
"\n",
"In subsequent chapters and volumes, we will discuss problems \n",
"arising from violations of the i.i.d. assumption. \n",
"For now, even taking the i.i.d. assumption for granted, \n",
"understanding generalization is a formidable problem.\n",
"Moreover, elucidating the precise theoretical foundations\n",
"that might explain why deep neural networks generalize as well as they do\n",
"continues to vex the greatest minds in learning theory.\n",
"\n",
"When we train our models, we are searching for a function \n",
"that fits the training data as well as possible. \n",
"If the function is so flexible that it can catch on to spurious patterns \n",
"just as easily as to the true associations, \n",
"then it might perform *too well* without producing a model \n",
"that generalizes well to unseen data.\n",
"This is precisely what we want to avoid (or at least control). \n",
"Many of the techniques in deep learning are heuristics and tricks\n",
"aimed at guarding against overfitting.\n",
"\n",
"### Model Complexity\n",
"\n",
"When we have simple models and abundant data, \n",
"we expect the generalization error to resemble the training error. \n",
"When we work with more complex models and fewer examples, \n",
"we expect the training error to go down but the generalization gap to grow. \n",
"What precisely constitutes model complexity is a complex matter. \n",
"Many factors govern whether a model will generalize well. \n",
"For example, a model with more parameters might be considered more complex. \n",
"A model whose parameters can take a wider range of values \n",
"might be more complex. \n",
"Often with neural networks, we think of a model \n",
"that takes more training steps as more complex, \n",
"and one subject to *early stopping* as less complex.\n",
"\n",
"It can be difficult to compare the complexity among members \n",
"of substantially different model classes \n",
"(say a decision tree versus a neural network). \n",
"For now, a simple rule of thumb is quite useful: \n",
"A model that can readily explain arbitrary facts \n",
"is what statisticians view as complex, \n",
"whereas one that has only a limited expressive power \n",
"but still manages to explain the data well \n",
"is probably closer to the truth. \n",
"In philosophy, this is closely related to Popper’s \n",
"criterion of [falsifiability](https://en.wikipedia.org/wiki/Falsifiability) \n",
"of a scientific theory: a theory is good if it fits data \n",
"and if there are specific tests which can be used to disprove it. \n",
"This is important since all statistical estimation is \n",
"[post hoc](https://en.wikipedia.org/wiki/Post_hoc), \n",
"i.e. we estimate after we observe the facts, \n",
"hence vulnerable to the associated fallacy. \n",
"For now, we'll put the philosophy aside and stick to more tangible issues.\n",
"\n",
"In this chapter, to give you some intuition, \n",
"we’ll focus on a few factors that tend \n",
"to influence the generalizability of a model class:\n",
"\n",
"1. The number of tunable parameters. When the number of tunable parameters, sometimes called the *degrees of freedom*, is large, models tend to be more susceptible to overfitting.\n",
"1. The values taken by the parameters. When weights can take a wider range of values, models can be more susceptible to overfitting.\n",
"1. The number of training examples. It’s trivially easy to overfit a dataset containing only one or two examples even if your model is simple. But overfitting a dataset with millions of examples requires an extremely flexible model.\n",
"\n",
"\n",
"## Model Selection\n",
"\n",
"In machine learning, we usually select our final model \n",
"after evaluating several candidate models. \n",
"This process is called model selection.\n",
"Sometimes the models subject to comparison \n",
"are fundamentally different in nature \n",
"(say, decision trees vs linear models).\n",
"At other times, we are comparing\n",
"members of the same class of models \n",
"that have been trained with different hyperparameter settings. \n",
"\n",
"With multilayer perceptrons for example, \n",
"we may wish to compare models with \n",
"different numbers of hidden layers,\n",
"different numbers of hidden units,\n",
"and various choices of the activation functions \n",
"applied to each hidden layer. \n",
"In order to determine the best among our candidate models, \n",
"we will typically employ a validation set.\n",
"\n",
"\n",
"### Validation Data Set\n",
"\n",
"In principle we should not touch our test set\n",
"until after we have chosen all our hyper-parameters.\n",
"Were we to use the test data in the model selection process,\n",
"there's a risk that we might overfit the test data.\n",
"Then we would be in serious trouble. \n",
"If we overfit our training data, \n",
"there's always the evaluation on test data to keep us honest.\n",
"But if we overfit the test data, how would we ever know?\n",
"\n",
"\n",
"Thus, we should never rely on the test data for model selection.\n",
"And yet we cannot rely solely on the training data\n",
"for model selection either because \n",
"we cannot estimate the generalization error\n",
"on the very data that we use to train the model.\n",
"\n",
"The common practice to address this problem\n",
"is to split our data three ways,\n",
"incorporating a *validation set*\n",
"in addition to the training and test sets.\n",
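"\n",
"As a minimal sketch of such a split, the NumPy code below shuffles example indices once and carves out disjoint test, validation, and training subsets. The 60/20/20 proportions and the name `train_val_test_split` are illustrative choices, not a prescription from this section.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def train_val_test_split(num_examples, val_frac=0.2, test_frac=0.2, seed=0):\n",
"    # Shuffle the indices once so the three subsets are random and disjoint.\n",
"    rng = np.random.default_rng(seed)\n",
"    idx = rng.permutation(num_examples)\n",
"    n_test = int(num_examples * test_frac)\n",
"    n_val = int(num_examples * val_frac)\n",
"    test_idx = idx[:n_test]\n",
"    val_idx = idx[n_test:n_test + n_val]\n",
"    train_idx = idx[n_test + n_val:]\n",
"    return train_idx, val_idx, test_idx\n",
"\n",
"train_idx, val_idx, test_idx = train_val_test_split(1000)\n",
"print(len(train_idx), len(val_idx), len(test_idx))\n",
"```\n",
"\n",
"Model selection then compares candidates on `val_idx` only, leaving `test_idx` untouched until the final evaluation.\n",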
"\n",
"\n",
"In practical applications, the picture gets muddier.\n",
"While ideally we would only touch the test data once,\n",
"to assess the very best model or to compare \n",
"a small number of models to each other,\n",
"real-world test data is seldom discarded after just one use.\n",
"We can seldom afford a new test set for each round of experiments.\n",
"\n",
"The result is a murky practice where the boundaries \n",
"between validation and test data are worryingly ambiguous. \n",
"Unless explicitly stated otherwise, in the experiments in this book\n",
"we are really working with what should rightly be called \n",
"training data and validation data, with no true test sets.\n",
"Therefore, the accuracy reported in each experiment\n",
"is really the validation accuracy and not a true test set accuracy. \n",
"The good news is that we don't need too much data in the validation set. \n",
"The uncertainty in our estimates can be shown \n",
"to be of the order of $O(n^{-\\frac{1}{2}})$, where $n$ is the number of validation examples.\n",
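"\n",
"To make this rate concrete, suppose each of $n$ i.i.d. validation examples is classified correctly with probability $p$. Then the accuracy estimate has standard error $\\sqrt{p(1-p)/n}$. The sketch below evaluates this for an assumed accuracy of 0.9.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def accuracy_std_error(acc, n):\n",
"    # Standard error of an accuracy estimated from n i.i.d. examples\n",
"    # (binomial model); it shrinks like n ** -0.5.\n",
"    return math.sqrt(acc * (1 - acc) / n)\n",
"\n",
"for n in [100, 1000, 10000]:\n",
"    print(n, round(accuracy_std_error(0.9, n), 4))\n",
"```\n",
"\n",
"Quadrupling the validation set only halves the standard error. Still, at this accuracy level, a few thousand validation examples already pin the estimate down to within about a percentage point.\n",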
"\n",
"\n",
"### $K$-Fold Cross-Validation\n",
"\n",
"When training data is scarce, \n",
"we might not even be able to afford to hold out \n",
"enough data to constitute a proper validation set.\n",
"One popular solution to this problem is to employ\n",
"*$K$-fold cross-validation*. \n",
"Here, the original training data is split into $K$ non-overlapping subsets.\n",
"Then model training and validation are executed $K$ times,\n",
"each time training on $K-1$ subsets and validating \n",
"on a different subset (the one not used for training in that round).\n",
"Finally, the training and validation error rates are estimated \n",
"by averaging over the results from the $K$ experiments.\n",
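"\n",
"The procedure can be sketched in a few lines of NumPy. The helper names `k_fold_indices` and `k_fold_errors`, and the toy model that always predicts the mean training label, are illustrative assumptions. Any training and error functions with the same shape could be plugged in.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def k_fold_indices(num_examples, k, seed=0):\n",
"    # Shuffle once, then split into k non-overlapping folds of near-equal size.\n",
"    rng = np.random.default_rng(seed)\n",
"    folds = np.array_split(rng.permutation(num_examples), k)\n",
"    for i in range(k):\n",
"        # Round i validates on fold i and trains on the other k-1 folds.\n",
"        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])\n",
"        yield train_idx, folds[i]\n",
"\n",
"def k_fold_errors(X, y, k, train_fn, error_fn):\n",
"    train_errs, val_errs = [], []\n",
"    for train_idx, val_idx in k_fold_indices(len(y), k):\n",
"        model = train_fn(X[train_idx], y[train_idx])\n",
"        train_errs.append(error_fn(model, X[train_idx], y[train_idx]))\n",
"        val_errs.append(error_fn(model, X[val_idx], y[val_idx]))\n",
"    # Average the training and validation errors over the k rounds.\n",
"    return float(np.mean(train_errs)), float(np.mean(val_errs))\n",
"\n",
"# Toy usage: the model is just the mean of the training labels.\n",
"def fit_mean(X_train, y_train):\n",
"    return y_train.mean()\n",
"\n",
"def squared_error(model, X_eval, y_eval):\n",
"    return float(np.mean((y_eval - model) ** 2))\n",
"\n",
"rng = np.random.default_rng(1)\n",
"X, y = rng.normal(size=(50, 3)), rng.normal(size=50)\n",
"print(k_fold_errors(X, y, k=5, train_fn=fit_mean, error_fn=squared_error))\n",
"```\n",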
"\n",
"\n",
"## Underfitting or Overfitting?\n",
"\n",
"When we compare the training and validation errors,\n",
"we want to be mindful of two common situations:\n",
"First, we want to watch out for cases\n",
"when our training error and validation error are both substantial\n",
"but there is little gap between them.\n",
"If the model is unable to reduce the training error,\n",
"that could mean that our model is too simple \n",
"(i.e., insufficiently expressive)\n",
"to capture the pattern that we are trying to model.\n",
"Moreover, since the *generalization gap* \n",
"between our training and validation errors is small,\n",
"we have reason to believe that we could get away with a more complex model.\n",
"This phenomenon is known as underfitting. \n",
"\n",
"On the other hand, as we discussed above,\n",
"we want to watch out for the cases \n",
"when our training error is significantly lower \n",
"than our validation error, indicating severe overfitting. \n",
"Note that overfitting is not always a bad thing.\n",
"With deep learning especially, it's well known \n",
"that the best predictive models often perform \n",
"far better on training data than on holdout data.\n",
"Ultimately, we usually care more about the validation error\n",
"than about the gap between the training and validation errors.\n",
"\n",
"Whether we overfit or underfit can depend \n",
"both on the complexity of our model \n",
"and the size of the available training datasets, \n",
"two topics that we discuss below.\n",
"\n",
"### Model Complexity\n",
"\n",
"To illustrate some classical intuition \n",
"about overfitting and model complexity,\n",
"we give an example using polynomials. \n",
"Given training data consisting of a single feature $x$ \n",
"and a corresponding real-valued label $y$, \n",
"we try to find the polynomial of degree $d$\n",
"\n",
"$$\\hat{y}= \\sum_{i=0}^d x^i w_i$$\n",
"\n",
"to estimate the labels $y$. \n",
"This is just a linear regression problem\n",
"where our features are given by the powers of $x$,\n",
"the $w_i$ are the model’s weights,\n",
"and the bias is given by $w_0$ since $x^0 = 1$ for all $x$. \n",
"Since this is just a linear regression problem,\n",
"we can use the squared error as our loss function.\n",
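"\n",
"Because the model is linear in the weights, it can be fit by ordinary least squares on a design matrix whose columns are $x^0, x^1, \\ldots, x^d$. The NumPy sketch below does this in closed form on hypothetical data from a noisy cubic. It illustrates the idea and is separate from the gradient-based Gluon implementation later in this section.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def poly_design_matrix(x, d):\n",
"    # Column i holds x**i for i = 0..d; column 0 is all ones (the bias w_0).\n",
"    return np.power.outer(x, np.arange(d + 1))\n",
"\n",
"# Hypothetical data: y = 1 + 2x - 3x^3 plus a little Gaussian noise.\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(-1, 1, size=30)\n",
"y = 1.0 + 2.0 * x - 3.0 * x ** 3 + 0.05 * rng.normal(size=30)\n",
"\n",
"X = poly_design_matrix(x, d=3)\n",
"# Least squares minimizes the squared error over w_0..w_3.\n",
"w, *_ = np.linalg.lstsq(X, y, rcond=None)\n",
"print('estimated weights w_0..w_3:', w)\n",
"```\n",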
"\n",
"\n",
"A higher-order polynomial function is more complex \n",
"than a lower order polynomial function, \n",
"since the higher-order polynomial has more parameters \n",
"and can represent a wider range of functions. \n",
"Fixing the training data set, \n",
"higher-order polynomial functions should always\n",
"achieve lower (at worst, equal) training error \n",
"relative to lower degree polynomials. \n",
"In fact, whenever the $n$ data points each have a distinct value of $x$,\n",
"a polynomial function of degree $n-1$ (or higher)\n",
"can fit the training set perfectly. \n",
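"\n",
"The following NumPy sketch illustrates both claims on a hypothetical set of six points. Refitting by least squares with degrees $0$ through $5$ never increases the training error, and degree $5 = n - 1$ drives it to (numerically) zero.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical tiny dataset: n = 6 points with distinct x values.\n",
"rng = np.random.default_rng(0)\n",
"x = np.linspace(-1, 1, 6)\n",
"y = np.sin(3 * x) + 0.1 * rng.normal(size=6)\n",
"\n",
"train_mse = []\n",
"for d in range(6):  # degrees 0 .. n-1\n",
"    X = np.power.outer(x, np.arange(d + 1))\n",
"    w, *_ = np.linalg.lstsq(X, y, rcond=None)\n",
"    train_mse.append(float(np.mean((X @ w - y) ** 2)))\n",
"    print(f'degree {d}: training MSE {train_mse[-1]:.2e}')\n",
"```\n",
"\n",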
"We visualize the relationship between polynomial degree\n",
"and under- vs over-fitting below.\n",
"\n",
"![Influence of Model Complexity on Underfitting and Overfitting](../img/capacity_vs_error.svg)\n",
"\n",
"\n",
"### Data Set Size\n",
"\n",
"The other big consideration to bear in mind is the dataset size.\n",
"Fixing our model, the fewer samples we have in the training dataset, \n",
"the more likely (and more severely) we are to encounter overfitting.\n",
"As we increase the amount of training data, \n",
"the generalization error typically decreases. \n",
"Moreover, in general, more data never hurts. \n",
"For a fixed task and data *distribution*,\n",
"there is typically a relationship between model complexity and dataset size.\n",
"Given more data, we might profitably attempt to fit a more complex model.\n",
"Absent sufficient data, simpler models may be difficult to beat.\n",
"For many tasks, deep learning only outperforms linear models\n",
"when many thousands of training examples are available.\n",
"In part, the current success of deep learning\n",
"owes to the current abundance of massive datasets \n",
"due to internet companies, cheap storage, connected devices,\n",
"and the broad digitization of the economy. \n",
"\n",
"## Polynomial Regression\n",
"\n",
"We can now explore these concepts interactively \n",
"by fitting polynomials to data. \n",
"To get started we'll import our usual packages."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
"\n",
"%matplotlib inline\n",
"import d2l\n",
"from mxnet import autograd, gluon, nd\n",
"from mxnet.gluon import data as gdata, loss as gloss, nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating Data Sets\n",
"\n",
"First we need data. Given $x$, we will use the following cubic polynomial to generate the labels on training and test data:\n",
"\n",
"$$y = 5 + 1.2x - 3.4\\frac{x^2}{2!} + 5.6 \\frac{x^3}{3!} + \\epsilon \\text{ where }\n",
"\\epsilon \\sim \\mathcal{N}(0, 0.1^2)$$\n",
"\n",
"The noise term $\\epsilon$ obeys a normal distribution \n",
"with a mean of 0 and a standard deviation of 0.1. \n",
"We'll synthesize 100 samples each for the training set and test set."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "2"
}
},
"outputs": [],
"source": [
"maxdegree = 20 # Maximum degree of the polynomial\n",
"n_train, n_test = 100, 100 # Training and test data set sizes\n",
"true_w = nd.zeros(maxdegree)  # Coefficients for degrees 0..19; only the first 4 are nonzero\n",
"true_w[0:4] = nd.array([5, 1.2, -3.4, 5.6])\n",
"\n",
"features = nd.random.normal(shape=(n_train + n_test, 1))\n",
"features = nd.random.shuffle(features)\n",
"poly_features = nd.power(features, nd.arange(maxdegree).reshape((1, -1)))\n",
"poly_features = poly_features / (\n",
" nd.gamma(nd.arange(maxdegree) + 1).reshape((1, -1)))\n",
"labels = nd.dot(poly_features, true_w)\n",
"labels += nd.random.normal(scale=0.1, shape=labels.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For optimization, we typically want to avoid \n",
"very large values of gradients, losses, etc.\n",
"This is why the monomials stored in `poly_features` \n",
"are rescaled from $x^i$ to $\\frac{1}{i!} x^i$. \n",
"It allows us to avoid very large values for large exponents $i$. \n",
"Factorials are implemented in Gluon using the Gamma function, \n",
"where $n! = \\Gamma(n+1)$.\n",
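"\n",
"Both points can be checked with Python's standard `math` module. It is used here in place of `nd.gamma` purely for illustration, and the evaluation point $x = 3$ is an arbitrary assumption.\n",
"\n",
"```python\n",
"import math\n",
"import numpy as np\n",
"\n",
"# Gamma(n + 1) equals n! for non-negative integers n.\n",
"assert all(math.isclose(math.gamma(n + 1), math.factorial(n)) for n in range(20))\n",
"\n",
"# Compare raw monomials x**i with rescaled ones x**i / i! for i = 0..19.\n",
"x = 3.0\n",
"i = np.arange(20)\n",
"raw = x ** i\n",
"scaled = raw / np.array([math.gamma(k + 1) for k in i])\n",
"print('largest raw monomial:', raw.max())\n",
"print('largest rescaled monomial:', scaled.max())\n",
"```\n",
"\n",
"The raw monomial $3^{19}$ exceeds $10^9$, whereas every rescaled entry stays below $5$.\n",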
"\n",
"Take a look at the first 2 samples from the generated data set. \n",
"The value 1 is technically a feature, \n",
"namely the constant feature corresponding to the bias."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "3"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(\n",
" [[1.5094751]\n",
" [1.9676613]]\n",
" , \n",
" [[1.0000000e+00 1.5094751e+00 1.1392574e+00 5.7322693e-01 2.1631797e-01\n",
" 6.5305315e-02 1.6429458e-02 3.5428370e-03 6.6847802e-04 1.1211676e-04\n",
" 1.6923748e-05 2.3223611e-06 2.9212887e-07 3.3920095e-08 3.6572534e-09\n",
" 3.6803552e-10 3.4721288e-11 3.0829944e-12 2.5853909e-13 2.0539909e-14]\n",
" [1.0000000e+00 1.9676613e+00 1.9358451e+00 1.2696959e+00 6.2458295e-01\n",
" 2.4579351e-01 8.0606394e-02 2.2658013e-02 5.5729118e-03 1.2184002e-03\n",
" 2.3973989e-04 4.2884261e-05 7.0318083e-06 1.0643244e-06 1.4958786e-07\n",
" 1.9622549e-08 2.4131586e-09 2.7931044e-10 3.0532687e-11 3.1619987e-12]]\n",
" , \n",
" [6.1804767 7.7595935]\n",
" )"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"features[:2], poly_features[:2], labels[:2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining, Training and Testing Model\n",
"\n",
"We first define the plotting function `semilogy`, \n",
"where the $y$ axis makes use of the logarithmic scale."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "4"
}
},
"outputs": [],
"source": [
"# This function has been saved in the d2l package for future use\n",
"def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,\n",
" legend=None, figsize=(3.5, 2.5)):\n",
" d2l.set_figsize(figsize)\n",
" d2l.plt.xlabel(x_label)\n",
" d2l.plt.ylabel(y_label)\n",
" d2l.plt.semilogy(x_vals, y_vals)\n",
" if x2_vals and y2_vals:\n",
" d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')\n",
" d2l.plt.legend(legend)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we will be attempting to fit the generated dataset \n",
"using models of varying complexity, \n",
"we insert the model definition into the `fit_and_plot` function. \n",
"The training and testing steps involved in polynomial function fitting \n",
"are similar to those previously described in softmax regression."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "5"
}
},
"outputs": [],
"source": [
"num_epochs, loss = 200, gloss.L2Loss()\n",
"\n",
"def fit_and_plot(train_features, test_features, train_labels, test_labels):\n",
" net = nn.Sequential()\n",
" # Switch off the bias since we already catered for it in the polynomial\n",
" # features\n",
" net.add(nn.Dense(1, use_bias=False))\n",
" net.initialize()\n",
" batch_size = min(10, train_labels.shape[0])\n",
" train_iter = gdata.DataLoader(gdata.ArrayDataset(\n",
" train_features, train_labels), batch_size, shuffle=True)\n",
" trainer = gluon.Trainer(net.collect_params(), 'sgd',\n",
" {'learning_rate': 0.01})\n",
" train_ls, test_ls = [], []\n",
" for _ in range(num_epochs):\n",
" for X, y in train_iter:\n",
" with autograd.record():\n",
" l = loss(net(X), y)\n",
" l.backward()\n",
" trainer.step(batch_size)\n",
" train_ls.append(loss(net(train_features),\n",
" train_labels).mean().asscalar())\n",
" test_ls.append(loss(net(test_features),\n",
" test_labels).mean().asscalar())\n",
" print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])\n",
" semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',\n",
" range(1, num_epochs + 1), test_ls, ['train', 'test'])\n",
" print('weight:', net[0].weight.data().asnumpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Third-order Polynomial Function Fitting (Normal)\n",
"\n",
"We will begin by first using a third-order polynomial function \n",
"with the same order as the data generation function. \n",
"The results show that this model’s training and test losses \n",
"are both low. \n",
"The trained model parameters are also close \n",
"to the true values $w = [5, 1.2, -3.4, 5.6]$."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"attributes": {
"classes": [],
"id": "",
"n": "6"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"final epoch: train loss 0.003981937 test loss 0.0038848561\n",
"weight: [[ 5.0161743 1.176134 -3.4185255 5.6422653]]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"