Custom Layers
=============
One factor behind deep learning’s success is the availability of a wide
range of layers that can be composed in creative ways to design
architectures suitable for a wide variety of tasks. For instance,
researchers have invented layers specifically for handling images and
text, for looping over sequential data, and for performing dynamic
programming. Sooner
or later, you will need a layer that does not exist yet in the deep
learning framework. In these cases, you must build a custom layer. In
this section, we show you how.
**PyTorch**

.. code:: python
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
**MXNet**

.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
**JAX**

.. code:: python
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
.. parsed-literal::
:class: output
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
**TensorFlow**

.. code:: python
import tensorflow as tf
from d2l import tensorflow as d2l
Layers without Parameters
-------------------------
To start, we construct a custom layer that does not have any parameters
of its own. This should look familiar if you recall our introduction to
modules in :numref:`sec_model_construction`. The following
``CenteredLayer`` class simply subtracts the mean from its input. To
build it, we need only inherit from the base layer class and implement
the forward propagation method.
**PyTorch**

.. code:: python
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
**MXNet**

.. code:: python
class CenteredLayer(nn.Block):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, X):
return X - X.mean()
**JAX**

.. code:: python
class CenteredLayer(nn.Module):
def __call__(self, X):
return X - X.mean()
**TensorFlow**

.. code:: python
class CenteredLayer(tf.keras.Model):
def __init__(self):
super().__init__()
def call(self, X):
return X - tf.reduce_mean(X)
Let’s verify that our layer works as intended by feeding some data
through it.
**PyTorch**

.. code:: python
layer = CenteredLayer()
layer(torch.tensor([1.0, 2, 3, 4, 5]))
.. parsed-literal::
:class: output
tensor([-2., -1., 0., 1., 2.])
**MXNet**

.. code:: python
layer = CenteredLayer()
layer(np.array([1.0, 2, 3, 4, 5]))
.. parsed-literal::
:class: output
[21:49:18] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. parsed-literal::
:class: output
array([-2., -1., 0., 1., 2.])
**JAX**

.. code:: python
layer = CenteredLayer()
layer(jnp.array([1.0, 2, 3, 4, 5]))
.. parsed-literal::
:class: output
Array([-2., -1., 0., 1., 2.], dtype=float32)
**TensorFlow**

.. code:: python
layer = CenteredLayer()
layer(tf.constant([1.0, 2, 3, 4, 5]))
.. parsed-literal::
:class: output
<tf.Tensor: shape=(5,), dtype=float32, numpy=array([-2., -1.,  0.,  1.,  2.], dtype=float32)>
We can now incorporate our layer as a component in constructing more
complex models.
**PyTorch**

.. code:: python
net = nn.Sequential(nn.LazyLinear(128), CenteredLayer())
**MXNet**

.. code:: python
net = nn.Sequential()
net.add(nn.Dense(128), CenteredLayer())
net.initialize()
**JAX**

.. code:: python
net = nn.Sequential([nn.Dense(128), CenteredLayer()])
**TensorFlow**

.. code:: python
net = tf.keras.Sequential([tf.keras.layers.Dense(128), CenteredLayer()])
As an extra sanity check, we can send random data through the network
and check that the mean is in fact 0. Because we are dealing with
floating-point numbers, we may still see a very small nonzero value due
to rounding.
**PyTorch**

.. code:: python
Y = net(torch.rand(4, 8))
Y.mean()
.. parsed-literal::
:class: output
tensor(-6.5193e-09, grad_fn=<MeanBackward0>)
**MXNet**

.. code:: python
Y = net(np.random.rand(4, 8))
Y.mean()
.. parsed-literal::
:class: output
array(3.783498e-10)
**JAX**
Here we utilize the ``init_with_output`` method, which returns both the
output of the network and its parameters. In this case we focus only on
the output.
.. code:: python
Y, _ = net.init_with_output(d2l.get_key(), jax.random.uniform(d2l.get_key(),
(4, 8)))
Y.mean()
.. parsed-literal::
:class: output
Array(-3.7252903e-09, dtype=float32)
**TensorFlow**

.. code:: python
Y = net(tf.random.uniform((4, 8)))
tf.reduce_mean(Y)
Layers with Parameters
----------------------
Now that we know how to define simple layers, let’s move on to defining
layers with parameters that can be adjusted through training. We can use
built-in functions to create parameters, which provide some basic
housekeeping functionality. In particular, they govern access,
initialization, sharing, saving, and loading model parameters. This way,
among other benefits, we will not need to write custom serialization
routines for every custom layer.
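To make this housekeeping concrete, here is a minimal PyTorch sketch
using a hypothetical single-parameter layer of our own: a parameter
registered via the built-in ``nn.Parameter`` function shows up in
``state_dict`` and can be saved and restored without any custom
serialization code.

.. code:: python

import torch
from torch import nn

class ScaleLayer(nn.Module):
    """Hypothetical layer with one learnable scaling parameter."""
    def __init__(self):
        super().__init__()
        # Registering via nn.Parameter enables the built-in housekeeping
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, X):
        return self.scale * X

layer = ScaleLayer()
torch.save(layer.state_dict(), 'scale.params')  # saving works out of the box
clone = ScaleLayer()
clone.load_state_dict(torch.load('scale.params'))  # and so does loading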
Now let’s implement our own version of the fully connected layer. Recall
that this layer requires two parameters, one to represent the weight and
the other for the bias. In this implementation, we bake in the ReLU
activation as a default. This layer requires two input arguments:
``in_units`` and ``units``, which denote the number of inputs and
outputs, respectively.
**PyTorch**

.. code:: python
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        # Use the parameters directly (not .data) so that gradients flow
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)
Next, we instantiate the ``MyLinear`` class and access its model
parameters.
.. code:: python
linear = MyLinear(5, 3)
linear.weight
.. parsed-literal::
:class: output
Parameter containing:
tensor([[ 0.4783, 0.4284, -0.0899],
[-0.6347, 0.2913, -0.0822],
[-0.4325, -0.1645, -0.3274],
[ 1.1898, 0.6482, -1.2384],
[-0.1479, 0.0264, -0.9597]], requires_grad=True)
**MXNet**

.. code:: python
class MyDense(nn.Block):
def __init__(self, units, in_units, **kwargs):
super().__init__(**kwargs)
self.weight = self.params.get('weight', shape=(in_units, units))
self.bias = self.params.get('bias', shape=(units,))
def forward(self, x):
linear = np.dot(x, self.weight.data(ctx=x.ctx)) + self.bias.data(
ctx=x.ctx)
return npx.relu(linear)
Next, we instantiate the ``MyDense`` class and access its model
parameters.
.. code:: python
dense = MyDense(units=3, in_units=5)
dense.params
.. parsed-literal::
:class: output
mydense0_ (
  Parameter mydense0_weight (shape=(5, 3), dtype=<class 'numpy.float32'>)
  Parameter mydense0_bias (shape=(3,), dtype=<class 'numpy.float32'>)
)
**JAX**

.. code:: python
class MyDense(nn.Module):
in_units: int
units: int
def setup(self):
self.weight = self.param('weight', nn.initializers.normal(stddev=1),
(self.in_units, self.units))
self.bias = self.param('bias', nn.initializers.zeros, self.units)
def __call__(self, X):
linear = jnp.matmul(X, self.weight) + self.bias
return nn.relu(linear)
Next, we instantiate the ``MyDense`` class and access its model
parameters.
.. code:: python
dense = MyDense(5, 3)
params = dense.init(d2l.get_key(), jnp.zeros((3, 5)))
params
.. parsed-literal::
:class: output
FrozenDict({
params: {
weight: Array([[-0.23823419, -0.70915407, 0.72494346],
[ 0.2568525 , -0.20872341, -0.8993567 ],
[ 0.80883664, 0.16673394, 0.75610644],
[-0.35652584, 0.13841456, -1.0971175 ],
[ 0.3117082 , 1.2280334 , -1.0946037 ]], dtype=float32),
bias: Array([0., 0., 0.], dtype=float32),
},
})
**TensorFlow**

.. code:: python
class MyDense(tf.keras.Model):
def __init__(self, units):
super().__init__()
self.units = units
def build(self, X_shape):
self.weight = self.add_weight(name='weight',
shape=[X_shape[-1], self.units],
initializer=tf.random_normal_initializer())
self.bias = self.add_weight(
name='bias', shape=[self.units],
initializer=tf.zeros_initializer())
def call(self, X):
linear = tf.matmul(X, self.weight) + self.bias
return tf.nn.relu(linear)
Next, we instantiate the ``MyDense`` class and access its model
parameters.
.. code:: python
dense = MyDense(3)
dense(tf.random.uniform((2, 5)))
dense.get_weights()
.. parsed-literal::
:class: output
[array([[-0.01007051, -0.05935554, 0.03142897],
[ 0.02453684, -0.01833588, -0.03096254],
[-0.09680572, -0.01736571, -0.00858052],
[-0.02245625, 0.02958351, -0.05780673],
[ 0.03997313, 0.01949595, -0.00150928]], dtype=float32),
array([0., 0., 0.], dtype=float32)]
We can directly carry out forward propagation calculations using custom
layers.
**PyTorch**

.. code:: python
linear(torch.rand(2, 5))
.. parsed-literal::
:class: output
tensor([[0.0000, 0.9316, 0.0000],
[0.1808, 1.4208, 0.0000]])
**MXNet**

.. code:: python
dense.initialize()
dense(np.random.uniform(size=(2, 5)))
.. parsed-literal::
:class: output
array([[0. , 0.01633355, 0. ],
[0. , 0.01581812, 0. ]])
**JAX**

.. code:: python
dense.apply(params, jax.random.uniform(d2l.get_key(),
(2, 5)))
.. parsed-literal::
:class: output
Array([[0.3850514 , 0. , 0.49188882],
[0.46509624, 0.26056105, 0. ]], dtype=float32)
**TensorFlow**

.. code:: python
dense(tf.random.uniform((2, 5)))
We can also construct models using custom layers. Once defined, such a
model can be used just like the built-in fully connected layer.
**PyTorch**

.. code:: python
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))
.. parsed-literal::
:class: output
tensor([[ 0.0000],
[13.0800]])
**MXNet**

.. code:: python
net = nn.Sequential()
net.add(MyDense(8, in_units=64),
MyDense(1, in_units=8))
net.initialize()
net(np.random.uniform(size=(2, 64)))
.. parsed-literal::
:class: output
array([[0.06508517],
[0.0615553 ]])
**JAX**

.. code:: python
net = nn.Sequential([MyDense(64, 8), MyDense(8, 1)])
Y, _ = net.init_with_output(d2l.get_key(), jax.random.uniform(d2l.get_key(),
(2, 64)))
Y
.. parsed-literal::
:class: output
Array([[0.],
[0.]], dtype=float32)
**TensorFlow**

.. code:: python
net = tf.keras.models.Sequential([MyDense(8), MyDense(1)])
net(tf.random.uniform((2, 64)))
Summary
-------
We can design custom layers via the basic layer class. This allows us to
define flexible new layers that behave differently from any existing
layers in the library. Once defined, custom layers can be invoked in
arbitrary contexts and architectures. Layers can have local parameters,
which can be created through built-in functions.
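As a final sanity check that such locally created parameters really are
trainable, a brief sketch (reusing the PyTorch ``MyLinear`` class and
imports from above; the check itself is our own addition) confirms that
they appear in ``net.parameters()`` and receive gradients after a
backward pass.

.. code:: python

# Assumes MyLinear and the PyTorch imports defined earlier in this section
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
loss = net(torch.rand(2, 64)).sum()
loss.backward()
# Each weight and bias created via nn.Parameter now carries a gradient
print([p.grad is not None for p in net.parameters()])  # [True, True, True, True]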
Exercises
---------
1. Design a layer that takes an input and computes a tensor reduction,
   i.e., it returns :math:`y_k = \sum_{i, j} W_{ijk} x_i x_j`.
2. Design a layer that returns the leading half of the Fourier
   coefficients of the data. (A sketch of possible solutions to both
   exercises follows below.)
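For orientation, here is one possible sketch of both exercises in
PyTorch; the shapes, names, and initializations are our own assumptions
rather than prescribed solutions. The first layer computes
:math:`y_k = \sum_{i, j} W_{ijk} x_i x_j` with ``torch.einsum``; the
second keeps the leading half of the Fourier coefficients returned by
``torch.fft.rfft``.

.. code:: python

class TensorReduction(nn.Module):
    """Exercise 1 sketch: y_k = sum_{i,j} W_{ijk} x_i x_j for a vector x."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_in, dim_in, dim_out))

    def forward(self, x):
        # Contract two copies of x against the first two axes of W
        return torch.einsum('ijk,i,j->k', self.weight, x, x)

class HalfFourier(nn.Module):
    """Exercise 2 sketch: return the leading half of the Fourier coefficients."""
    def forward(self, x):
        coeffs = torch.fft.rfft(x)  # non-redundant spectrum of a real signal
        return coeffs[..., :coeffs.shape[-1] // 2]

TensorReduction(5, 3)(torch.randn(5)), HalfFourier()(torch.randn(8))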