
Backpropagation

The following are notes I prepared for the students as part of tutoring the Neural Networks class at Saarland University in fall 2023. Originally typeset in LaTeX, I decided that they would also make a nice addition to my website. Most of the content is a more detailed discussion of the backpropagation section in Christopher Bishop's excellent PRML book (Bishop, 2006), with the addition of a few examples.

Consider a fully connected feed-forward neural network with one hidden layer of $M$ neurons, $D$ inputs, and $K$ output units. A network of this form is depicted in Figure 1. A forward pass involves the following calculations:

$$
\begin{aligned}
\mathbf{a}^{(1)} &= \boldsymbol{W}^{(1)}\mathbf{x} & a^{(1)}_m &= \sum_{d=1}^D \boldsymbol{W}^{(1)}_{md}x_d\\
\mathbf{z} &= h^{(1)}\left(\mathbf{a}^{(1)}\right) & z_m &= h^{(1)}\left(a^{(1)}_m\right)\\
\mathbf{a}^{(2)} &= \boldsymbol{W}^{(2)}\mathbf{z} & a^{(2)}_k &= \sum_{m=1}^M \boldsymbol{W}^{(2)}_{km}z_m\\
\mathbf{\hat{y}} &= h^{(2)}\left(\mathbf{a}^{(2)}\right) & \hat{y}_k &= h^{(2)}\left(a^{(2)}_k\right)
\end{aligned}
$$
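As a concrete companion to these equations, here is a minimal NumPy sketch of the forward pass for a single sample. The sizes, the tanh hidden activation, and the identity output activation are arbitrary example choices, not prescribed by the text:

```python
import numpy as np

rng = np.random.default_rng(0)

D, M, K = 3, 4, 2                # input, hidden and output sizes (arbitrary example)
W1 = rng.normal(size=(M, D))     # W^(1)
W2 = rng.normal(size=(K, M))     # W^(2)

h1 = np.tanh                     # hidden activation h^(1) (example choice)
h2 = lambda a: a                 # output activation h^(2): identity, as in regression

x = rng.normal(size=D)           # a single input sample

a1 = W1 @ x                      # a^(1) = W^(1) x
z = h1(a1)                       # z = h^(1)(a^(1))
a2 = W2 @ z                      # a^(2) = W^(2) z
y_hat = h2(a2)                   # y_hat = h^(2)(a^(2))
```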

Here $\mathbf{x} \in \mathbb{R}^{D}$, $\mathbf{a}^{(1)}, \mathbf{z} \in \mathbb{R}^{M}$, $\mathbf{a}^{(2)}, \mathbf{\hat{y}} \in \mathbb{R}^{K}$, and $\boldsymbol{W}^{(1)} \in \mathbb{R}^{M \times D}$, $\boldsymbol{W}^{(2)} \in \mathbb{R}^{K \times M}$. $h^{(1)}$ is the activation function applied after the first set of activations $\mathbf{a}^{(1)}$, and $h^{(2)}$ is the output activation function. Now we will consider a separable loss or error function $\mathrm{J}$:

$$
\mathrm{J}\left(\{(\mathbf{x}_n, \mathbf{y}_n)\}_{n=1}^N; \boldsymbol{W}^{(1)}, \boldsymbol{W}^{(2)}\right) = \sum_{n=1}^N \mathrm{J}_n\left(\mathbf{x}_n, \mathbf{y}_n; \boldsymbol{W}^{(1)}, \boldsymbol{W}^{(2)}\right)
$$

To minimize the loss we are interested in the gradients with respect to the weight matrices. From now on we will only consider the loss $\mathrm{J}_n$ of a single sample $(\mathbf{x}_n, \mathbf{y}_n)$. Let's first consider the derivative of the loss w.r.t. entry $\boldsymbol{W}^{(2)}_{ji}$. Because $\boldsymbol{W}^{(2)}_{ji}$ influences the output $\mathbf{\hat{y}}$ only through the activation $a_j^{(2)}$, by applying the chain rule we may write:

$$
\frac{\partial \mathrm{J}_n}{\partial \boldsymbol{W}^{(2)}_{ji}} = \frac{\partial \mathrm{J}_n}{\partial a_j^{(2)}} \frac{\partial a_j^{(2)}}{\partial \boldsymbol{W}^{(2)}_{ji}}
$$

We can further simplify by again applying the chain rule, this time to the first term:

$$
= \frac{\partial \mathrm{J}_n}{\partial \hat{y}_j} \frac{\partial \hat{y}_j}{\partial a_j^{(2)}} \frac{\partial a_j^{(2)}}{\partial \boldsymbol{W}^{(2)}_{ji}}
$$

Here $\frac{\partial \mathrm{J}_n}{\partial \hat{y}_j}$ depends on the loss function. The derivatives of, for example, the squared error and the cross-entropy error function read:

$$
\frac{\partial \mathrm{J}_n}{\partial \hat{y}_j} =
\begin{cases}
\hat{y}_j - y_{nj} & \text{when } \mathrm{J}_n = \frac{1}{2}\left(\hat{y}_j - y_{nj}\right)^2\\[4pt]
-\dfrac{y_{nj}}{\hat{y}_j} & \text{when } \mathrm{J}_n = -\sum_{k=1}^K y_{nk} \log(\hat{y}_k)
\end{cases}
$$
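Both cases translate directly into code. A minimal sketch, with function names of my own choosing; the cross-entropy variant assumes strictly positive entries in $\hat{y}$:

```python
import numpy as np

def dJ_dyhat_squared(y_hat, y):
    """Gradient of J_n = 1/2 * (y_hat_j - y_j)^2, summed over j, w.r.t. y_hat."""
    return y_hat - y

def dJ_dyhat_cross_entropy(y_hat, y):
    """Gradient of J_n = -sum_k y_k * log(y_hat_k) w.r.t. y_hat."""
    return -y / y_hat

y_hat = np.array([0.7, 0.2, 0.1])
y = np.array([1.0, 0.0, 0.0])
print(dJ_dyhat_squared(y_hat, y))        # [-0.3  0.2  0.1]
print(dJ_dyhat_cross_entropy(y_hat, y))  # approx. [-1.43 -0.  -0. ]
```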

The second term $\frac{\partial \hat{y}_j}{\partial a_j^{(2)}}$ is determined solely by the output activation function $h^{(2)}$ and reduces to 1 if it is the identity function $h^{(2)}\left(\mathbf{a}^{(2)}\right) = \mathbf{a}^{(2)}$ (as is usually the case in regression).
The remaining term is trivial to calculate:

$$
\frac{\partial a_j^{(2)}}{\partial \boldsymbol{W}^{(2)}_{ji}} = \frac{\partial}{\partial \boldsymbol{W}^{(2)}_{ji}} \left( \sum_{m=1}^M \boldsymbol{W}^{(2)}_{jm} z_m \right) = z_i
$$

Therefore the partial derivative for entry $ji$ of $\boldsymbol{W}^{(2)}$ reads:

$$
\begin{aligned}
\frac{\partial \mathrm{J}_n}{\partial \boldsymbol{W}^{(2)}_{ji}} &= \frac{\partial \mathrm{J}_n}{\partial a_j^{(2)}} \frac{\partial a_j^{(2)}}{\partial \boldsymbol{W}^{(2)}_{ji}}\\
&= \left[ \frac{\partial \mathrm{J}_n}{\partial \hat{y}_j} \frac{\partial \hat{y}_j}{\partial a_j^{(2)}} \right] \cdot z_i\\
&= \underbrace{\left[ \frac{\partial \mathrm{J}_n}{\partial \hat{y}_j} \cdot h^{(2)\prime}\left(a_j^{(2)}\right) \right]}_{\delta^{(2)}_j} \cdot z_i = \delta^{(2)}_j \cdot z_i
\end{aligned}
$$
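Collecting these pieces, the whole gradient matrix of $\boldsymbol{W}^{(2)}$ is the outer product of the error vector $\boldsymbol{\delta}^{(2)}$ and the hidden activations $\mathbf{z}$. A minimal NumPy sketch, assuming the squared error and an identity output activation (all sizes and names are example choices, not taken from the text):

```python
import numpy as np

rng = np.random.default_rng(0)
D, M, K = 3, 4, 2
W1 = rng.normal(size=(M, D))
W2 = rng.normal(size=(K, M))
x = rng.normal(size=D)
y = rng.normal(size=K)           # regression target for this example

# forward pass: tanh hidden units, identity output
a1 = W1 @ x
z = np.tanh(a1)
a2 = W2 @ z
y_hat = a2

# delta^(2)_j = dJ_n/dy_hat_j * h^(2)'(a^(2)_j); the identity output makes h^(2)' = 1
delta2 = (y_hat - y) * 1.0

# dJ_n/dW^(2)_{ji} = delta^(2)_j * z_i, for all j and i at once
grad_W2 = np.outer(delta2, z)    # shape (K, M), matching W2
```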

We denote this term by $\delta^{(2)}_j := \frac{\partial \mathrm{J}_n}{\partial a_j^{(2)}}$ and call it the error term. The calculation of the partial derivatives w.r.t. $\boldsymbol{W}^{(1)}_{ji}$ is a little more involved, because it requires careful consideration of which terms in later layers are influenced by $\boldsymbol{W}^{(1)}_{ji}$.
As before, $\boldsymbol{W}^{(1)}_{ji}$ influences $\mathbf{\hat{y}}$ only through $a^{(1)}_j$, hence we may write:

$$
\frac{\partial \mathrm{J}_n}{\partial \boldsymbol{W}^{(1)}_{ji}} = \frac{\partial \mathrm{J}_n}{\partial a^{(1)}_j} \frac{\partial a^{(1)}_j}{\partial \boldsymbol{W}^{(1)}_{ji}}
$$

To begin with, we focus on the first term. Notice that $\mathrm{J}_n$ depends on $a^{(1)}_j$ only through $z_j$:

$$
\frac{\partial \mathrm{J}_n}{\partial a^{(1)}_j} = \frac{\partial \mathrm{J}_n}{\partial z_j} \frac{\partial z_j}{\partial a^{(1)}_j}
$$

This time, however, $z_j$ is involved in the calculation of all output units $\hat{y}_k$ through $a^{(2)}_k = \sum_{m=1}^M \boldsymbol{W}^{(2)}_{km} z_m$. Therefore, to express $\frac{\partial \mathrm{J}_n}{\partial z_j}$ in terms of $a^{(2)}_k$, we need to sum over all $k = 1, \dots, K$:

$$
= \sum_{k=1}^{K} \frac{\partial \mathrm{J}_n}{\partial a^{(2)}_k} \frac{\partial a^{(2)}_k}{\partial z_j} \frac{\partial z_j}{\partial a^{(1)}_j}
$$

By definition, we know $\frac{\partial \mathrm{J}_n}{\partial a^{(2)}_k} = \delta^{(2)}_k$. It's easy to verify that $\frac{\partial z_j}{\partial a^{(1)}_j} = h^{(1)\prime}\left(a^{(1)}_j\right)$, and similarly, by writing out the definition of $a^{(2)}_k$, it's clear that $\frac{\partial a^{(2)}_k}{\partial z_j} = \boldsymbol{W}^{(2)}_{kj}$. Therefore we obtain:

$$
\begin{aligned}
&= \sum_{k=1}^{K} \delta^{(2)}_k \boldsymbol{W}^{(2)}_{kj} h^{(1)\prime}\left(a^{(1)}_j\right)\\
&= h^{(1)\prime}\left(a^{(1)}_j\right) \cdot \sum_{k=1}^{K} \boldsymbol{W}^{(2)}_{kj} \delta^{(2)}_k
\end{aligned}
$$

Only the partial derivative $\frac{\partial a^{(1)}_j}{\partial \boldsymbol{W}^{(1)}_{ji}}$ is left, and its calculation is again straightforward:

$$
\frac{\partial a^{(1)}_j}{\partial \boldsymbol{W}^{(1)}_{ji}} = \frac{\partial}{\partial \boldsymbol{W}^{(1)}_{ji}} \left( \sum_{d=1}^D \boldsymbol{W}^{(1)}_{jd} x_d \right) = x_i
$$

Consequently, every entry $ji$ of $\boldsymbol{W}^{(1)}$ has the partial derivative:

$$
\begin{aligned}
\frac{\partial \mathrm{J}_n}{\partial \boldsymbol{W}^{(1)}_{ji}} &= \frac{\partial \mathrm{J}_n}{\partial a^{(1)}_j} \frac{\partial a^{(1)}_j}{\partial \boldsymbol{W}^{(1)}_{ji}}\\
&= \left[ \sum_{k=1}^{K} \frac{\partial \mathrm{J}_n}{\partial a^{(2)}_k} \frac{\partial a^{(2)}_k}{\partial z_j} \frac{\partial z_j}{\partial a^{(1)}_j} \right] \frac{\partial a^{(1)}_j}{\partial \boldsymbol{W}^{(1)}_{ji}}\\
&= \left[ \sum_{k=1}^{K} \frac{\partial \mathrm{J}_n}{\partial a^{(2)}_k} \frac{\partial a^{(2)}_k}{\partial z_j} \frac{\partial z_j}{\partial a^{(1)}_j} \right] \cdot x_i\\
&= \underbrace{\left[ h^{(1)\prime}\left(a^{(1)}_j\right) \cdot \sum_{k=1}^{K} \boldsymbol{W}^{(2)}_{kj} \delta^{(2)}_k \right]}_{\delta^{(1)}_j} \cdot x_i = \delta^{(1)}_j \cdot x_i
\end{aligned}
$$
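Putting the two layers together, the full backward pass for this network fits in a few lines. The sketch below again assumes tanh hidden units, an identity output, and the squared error, and checks one entry of the resulting gradient against a finite-difference estimate:

```python
import numpy as np

rng = np.random.default_rng(0)
D, M, K = 3, 4, 2
W1 = rng.normal(size=(M, D))
W2 = rng.normal(size=(K, M))
x = rng.normal(size=D)
y = rng.normal(size=K)

def forward(W1, W2, x):
    a1 = W1 @ x
    z = np.tanh(a1)
    y_hat = W2 @ z                              # identity output activation
    return a1, z, y_hat

a1, z, y_hat = forward(W1, W2, x)

# backward pass for the squared error J_n = 1/2 * sum_k (y_hat_k - y_k)^2
delta2 = y_hat - y                              # error term of the output layer
delta1 = (1.0 - np.tanh(a1) ** 2) * (W2.T @ delta2)   # h^(1)'(a^(1)_j) * sum_k W^(2)_{kj} delta^(2)_k
grad_W2 = np.outer(delta2, z)                   # dJ_n/dW^(2)_{ji} = delta^(2)_j * z_i
grad_W1 = np.outer(delta1, x)                   # dJ_n/dW^(1)_{ji} = delta^(1)_j * x_i

# finite-difference check of a single entry of grad_W1
eps = 1e-6
W1_pert = W1.copy()
W1_pert[1, 2] += eps
_, _, y_hat_pert = forward(W1_pert, W2, x)
numeric = (0.5 * np.sum((y_hat_pert - y) ** 2) - 0.5 * np.sum((y_hat - y) ** 2)) / eps
print(np.isclose(numeric, grad_W1[1, 2], atol=1e-4))  # expected: True
```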

Note that although we assumed a neural network with a single hidden layer, the partial derivatives for deeper layers take the same form as for $\boldsymbol{W}^{(1)}$. This can be seen by defining $\mathbf{x} = \mathbf{z}^{(0)} = h^{(0)}\left(\mathbf{a}^{(0)}\right)$ as the output of the previous layer, with $\mathbf{a}^{(0)} = \boldsymbol{W}^{(0)}\mathbf{z}^{(-1)}$.
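A minimal sketch of that generalization, again with arbitrary example sizes, tanh hidden activations, and an identity output: the recursion $\boldsymbol{\delta}^{(l)} = h^{(l)\prime}\left(\mathbf{a}^{(l)}\right) \odot \boldsymbol{W}^{(l+1)\top}\boldsymbol{\delta}^{(l+1)}$ and the gradient $\boldsymbol{\delta}^{(l)} \mathbf{z}^{(l-1)\top}$ are applied layer by layer:

```python
import numpy as np

rng = np.random.default_rng(0)
sizes = [3, 5, 4, 2]                 # D, two hidden layers, K (example sizes)
Ws = [rng.normal(size=(n_out, n_in))
      for n_in, n_out in zip(sizes[:-1], sizes[1:])]

x = rng.normal(size=sizes[0])
y = rng.normal(size=sizes[-1])

# forward pass: tanh on hidden layers, identity on the output layer
zs, acts = [x], []                   # zs[l] is the input to layer l, acts[l] its pre-activation
for l, W in enumerate(Ws):
    a = W @ zs[-1]
    acts.append(a)
    zs.append(np.tanh(a) if l < len(Ws) - 1 else a)

# backward pass: start with the output error, then apply the recursion per layer
delta = zs[-1] - y                   # squared error + identity output
grads = [None] * len(Ws)
for l in reversed(range(len(Ws))):
    grads[l] = np.outer(delta, zs[l])                               # grad of layer l: outer(delta^(l), z^(l-1))
    if l > 0:
        delta = (1.0 - np.tanh(acts[l - 1]) ** 2) * (Ws[l].T @ delta)
```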


Figure 1: Example feed-forward neural network with one hidden layer.

References
  1. Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer-Verlag New York. https://link.springer.com/book/9780387310732