
# The product rule in RNNs

I find that the product rule is usually left out of popular blog posts (see [1] and [2]) discussing RNNs and backpropagation through time (BPTT). It is clear what is happening in those posts, but WHY exactly, in a mathematical sense, does the last output depend on all previous states? To answer this, let us look at the product rule [3].

Consider the following unrolled RNN.

Assume the following:

$$h_t = \sigma(W * h_{t-1} + Ux_t)$$

$$y_t = \mathrm{softmax}(V * h_t)$$
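As a minimal sketch of these two definitions, the forward pass could look like the following in plain Python. It assumes that $\sigma$ is the logistic sigmoid (the post does not fix the activation) and stores vectors and matrices as lists; the helper names `matvec` and `rnn_forward` and the toy weights are made up for illustration.

```python
import math

def sigmoid(z):
    # Assumption: sigma is the logistic sigmoid.
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, v):
    # Matrix-vector product for a list-of-rows matrix.
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(z):
    # Subtract the maximum for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def rnn_forward(W, U, V, xs, h0):
    """Return the hidden states [h_1, ..., h_T] and outputs [y_1, ..., y_T]."""
    h = h0
    hs, ys = [], []
    for x in xs:
        # h_t = sigma(W h_{t-1} + U x_t)
        pre = [a + b for a, b in zip(matvec(W, h), matvec(U, x))]
        h = [sigmoid(p) for p in pre]
        hs.append(h)
        # y_t = softmax(V h_t)
        ys.append(softmax(matvec(V, h)))
    return hs, ys

# Toy example: 2 hidden units, 1 input feature, 3 output classes, 3 time steps.
W = [[0.5, -0.2], [0.1, 0.3]]
U = [[1.0], [-1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hs, ys = rnn_forward(W, U, V, xs=[[1.0], [0.5], [-1.0]], h0=[0.0, 0.0])
print(hs[-1], ys[-1])
```

Note that the same `W` is reused at every step of the loop, which is what makes the derivative with respect to `W` interesting.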

Using a mix of Leibniz's and Lagrange's notation, I now differentiate $h_3$ with respect to $W$:

$$\frac{\partial h_3}{\partial W} = \frac{\partial \sigma(Wh_2 + Ux_3)}{\partial W} =$$

$$\sigma' * [Wh_2 + Ux_3]' = \qquad \text{// Chain rule}$$

$$\sigma' * [Wh_2]' =$$

$$\sigma' * [W * \sigma(Wh_1 + Ux_2)]' =$$

$$\sigma' * (h_2 + W * h_2') = \qquad \text{// Product rule}$$

$$\sigma' * (h_2 + W * \sigma' * [Wh_1 + Ux_2]') =$$

$$\sigma' * (h_2 + W * \sigma' * (h_1 + W * \sigma' * (h_0 + W * h_0'))) =$$

$$\sigma_{h_3}' * h_2 \mathbf{+} \sigma_{h_3}' * W * \sigma_{h_2}' * h_1 \mathbf{+} \sigma_{h_3}' * W * \sigma_{h_2}' * W * \sigma_{h_1}' * h_0 \mathbf{+} \sigma_{h_3}' * W * \sigma_{h_2}' * W * \sigma_{h_1}' * W * h_{0}'$$

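The chain rule is applied from line 1 to line 2, and the product rule from line 4 to line 5. Line 3 follows because $Ux_3$ does not contain $W$ (the variable we are differentiating with respect to), so its derivative vanishes. In the last line, $\sigma_{h_t}'$ denotes $\sigma'$ evaluated at the argument that produced $h_t$. It can now be seen immediately that each summand of the final result reaches one step further into the past: the first involves $h_2$, the second $h_1$, and the last two $h_0$.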
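The expansion can be checked numerically. The sketch below is only an illustration under a few assumptions that are not in the derivation itself: everything is scalar, $\sigma$ is the logistic sigmoid, and $h_0$ is a constant that does not depend on $W$, so $h_0' = 0$ and the last summand drops out. The function names and the toy values are made up. It builds the remaining summands one by one and compares their sum against a central finite difference.

```python
import math

def sigmoid(z):
    # Assumption: sigma is the logistic sigmoid.
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def forward(W, U, xs, h0):
    """Return hidden states [h_0, ..., h_T] and pre-activations [z_1, ..., z_T]."""
    hs, zs = [h0], []
    for x in xs:
        z = W * hs[-1] + U * x
        zs.append(z)
        hs.append(sigmoid(z))
    return hs, zs

def summands(W, U, xs, h0):
    """Summands of dh_T/dW, most recent first, with h_0 treated as a constant."""
    hs, zs = forward(W, U, xs, h0)
    d = [sigmoid_prime(z) for z in zs]  # d[k-1] is sigma' evaluated for h_k
    terms = []
    prefix = 1.0
    for k in range(len(xs), 0, -1):
        prefix *= d[k - 1]              # sigma'_{h_T} W ... W sigma'_{h_k}
        terms.append(prefix * hs[k - 1])
        prefix *= W                     # one more factor W per step into the past
    return terms

def numerical_grad(W, U, xs, h0, eps=1e-6):
    """Central finite difference of h_T with respect to W."""
    h_plus = forward(W + eps, U, xs, h0)[0][-1]
    h_minus = forward(W - eps, U, xs, h0)[0][-1]
    return (h_plus - h_minus) / (2 * eps)

# Toy values: three time steps, so the sum has three non-zero summands.
W, U, h0 = 0.5, -0.3, 0.1
xs = [1.0, 2.0, -1.0]  # x_1, x_2, x_3
terms = summands(W, U, xs, h0)
analytic = sum(terms)
numeric = numerical_grad(W, U, xs, h0)
print("summands:", terms)
print("analytic:", analytic, "numeric:", numeric)
```

The two printed derivatives should agree up to the finite-difference error. Each pass through the loop in `summands` multiplies in one more $\sigma' * W$ factor, which is the "further and further into the past" pattern in code form.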

Lastly, since this post assumes familiarity with the topic, a really nice further explanation of BPTT for the interested reader can be found here.
