About the author:
Daniel is co-founder and CTO at Aqarios GmbH. He holds a M.Sc. in Computer Science from LMU Munich, and has published papers in reinforcement learning and quantum computing. He writes about technical topics in quantum computing and startups.

# The product rule in RNNs

I find that the product rule is always forgotten in popular blog posts (see [1] and [2]) discussing RNNs and backpropagation through time (BPTT). It is clear *what* is happening in those posts, but WHY exactly, in a mathematical sense, does the last output depend on all previous states? For this, let us look at the product rule.

Consider the following unrolled RNN.

Assume the following:

h_t = \sigma(W * h_{t-1} + U x_t)

y_t = \mathrm{softmax}(V * h_t)
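As a concrete sketch of these two equations (with tanh as a common choice for σ, and hypothetical dimensions and variable names), a forward pass through the unrolled RNN might look like:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax
    e = np.exp(z - z.max())
    return e / e.sum()

def rnn_forward(xs, W, U, V, h0):
    """Unrolled RNN: h_t = sigma(W h_{t-1} + U x_t), y_t = softmax(V h_t)."""
    h, ys = h0, []
    for x in xs:
        h = np.tanh(W @ h + U @ x)  # sigma = tanh here
        ys.append(softmax(V @ h))
    return h, ys

# Toy dimensions, purely for illustration
rng = np.random.default_rng(0)
d_h, d_x, d_y = 4, 3, 2
W = rng.normal(size=(d_h, d_h))
U = rng.normal(size=(d_h, d_x))
V = rng.normal(size=(d_y, d_h))
xs = [rng.normal(size=d_x) for _ in range(3)]

h3, ys = rnn_forward(xs, W, U, V, np.zeros(d_h))
```

Note that the same `W` is applied at every time step; that weight sharing is exactly what makes the product rule show up in the derivative below.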

Using a mix of Leibniz's and Lagrange's notation, I now derive:

\frac{\partial h_3}{\partial W} = \frac{\partial \sigma(W h_2 + U x_3)}{\partial W} =

\sigma' * [W h_2 + U x_3]' = // Chain rule

\sigma' * [W h_2]' =

\sigma' * [W * \sigma(W h_1 + U x_2)]' =

\sigma' * (h_2 + W * h_2') = // Product rule

\sigma' * (h_2 + W * \sigma' * [W h_1 + U x_2]') =

\sigma' * (h_2 + W * \sigma' * (h_1 + W * \sigma' * (h_0 + W * h_0'))) =

\sigma_{h_3}' * h_2 + \sigma_{h_3}' * W * \sigma_{h_2}' * h_1 + \sigma_{h_3}' * W * \sigma_{h_2}' * W * \sigma_{h_1}' * h_0 + \sigma_{h_3}' * W * \sigma_{h_2}' * W * \sigma_{h_1}' * W * h_0'

The chain rule is applied from line 1 to line 2, and the product rule from line 4 to line 5. Line 3 follows because the term Ux does not contain W, the variable we are differentiating with respect to. It is now immediately visible that each summand of the final result reaches further and further into the past.
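As a quick sanity check of the recursion behind this derivation, here is a scalar toy version (tanh as σ, hypothetical values; h_0 is a constant, so h_0' = 0) that accumulates ∂h_3/∂W via the chain and product rule and compares it against a central finite difference:

```python
import math

def forward(w, u, xs, h0):
    """Scalar RNN h_t = tanh(w*h_{t-1} + u*x_t); returns all hidden states."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def grad_dh_dw(w, u, xs, h0):
    """Accumulate dh_T/dw with the recursion dh_t/dw = sigma' * (h_{t-1} + w * dh_{t-1}/dw)."""
    h, dh = h0, 0.0  # dh_0/dw = 0: h_0 does not depend on w
    for x in xs:
        h_new = math.tanh(w * h + u * x)
        dh = (1 - h_new**2) * (h + w * dh)  # tanh' = 1 - tanh^2; product rule inside
        h = h_new
    return dh

w, u, h0 = 0.5, 0.3, 0.1
xs = [1.0, -0.5, 2.0]

analytic = grad_dh_dw(w, u, xs, h0)
eps = 1e-6
numeric = (forward(w + eps, u, xs, h0)[-1] - forward(w - eps, u, xs, h0)[-1]) / (2 * eps)
```

The two values agree up to finite-difference error, which confirms that the gradient of the last state really does collect a contribution from every earlier state.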

Lastly, since this post assumes the reader is already familiar with the topic, the interested reader can find a really nice further explanation of BPTT here.
