RNN

A language model computes the probability of a sequence of words, \(P(word_1, word_2, \dots, word_t)\). By the chain rule, this joint probability factorizes into a product of conditional probabilities:

\[ P(word_1, word_2, \dots, word_{t}) = \prod_{i = 1}^{t} P(word_i|word_1, \dots, word_{i-1}). \]

The drawback of traditional (n-gram) language models is that they approximate each conditional probability by conditioning on only the last few words (a Markov assumption), which does not hold in reality, and they require a lot of RAM to store the conditional probabilities for every word context. One solution is to apply RNNs: instead of looking at only a fixed window of input words, an RNN carries information from the entire previous sequence of words and uses it when predicting the current word (the recurrent part).
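
To make the chain-rule factorization concrete, here is a minimal Python sketch; the sentence and its conditional probabilities are made up purely for illustration.

```python
# Minimal sketch: the chain-rule factorization of a sentence probability.
# The conditional probabilities below are made up for illustration only.
import math

# P(word_i | word_1, ..., word_{i-1}) for the toy sentence "the cat sat"
conditionals = [
    0.10,  # P("the")
    0.03,  # P("cat" | "the")
    0.20,  # P("sat" | "the", "cat")
]

# P(word_1, ..., word_t) = product of the conditionals (computed in log space)
log_prob = sum(math.log(p) for p in conditionals)
print(math.exp(log_prob))  # 0.10 * 0.03 * 0.20 = 0.0006
```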

The Model

Define

\[\mathbf{x}_t \in \mathbb{R}^{d}, \mathbf{s}_t \in \mathbb{R}^{D_s}, \mathbf{W} \in \mathbb{R}^{D_s \times D_s}, \mathbf{U} \in \mathbb{R}^{D_s \times d}, \mathbf{V} \in \mathbb{R}^{Voc \times D_s}, \mathbf{\hat{y}} \in \mathbb{R}^{Voc},\]

Converting the figure above into math expressions: \[ \begin{array}{rcl} \mathbf{s}_t &=& f\left(\mathbf{W}\mathbf{s}_{t-1} + \mathbf{U}\mathbf{x}_{t}\right)\\ \mathbf{\hat{y}}_t &=& \mathrm{softmax}\left(\mathbf{V}\mathbf{s}_t\right), \quad t = 1,2,\dots, T\\ \hat{p}(w_{t+1} = v_j|w_t, \dots,w_1) &=& \hat{y}_{t,j}, \quad j = 1, 2, \dots, Voc \end{array} \]

where

  • \(\mathbf{s}_t\) is the hidden state at time \(t\); it is a function (\(f\) can be tanh, sigmoid, or ReLU) of the previous hidden state \(\mathbf{s}_{t-1}\) and the word vector of the current input word \(\mathbf{x}_t\)
  • \(\hat{\mathbf{y}}_t\) is the predicted outcome (a vector of probabilities over all vocabulary) at time \(t\)
  • \(\hat{p}\) is the probability that the predicted word \(w_{t+1}\) (at time \(t+1\)) equals the \(j\)th word in the vocabulary \(v_j\), given the sequence of all previous input words; \(\hat{p}(w_{t+1}= v_j|w_t, \dots,w_1)\) is the \(j\)th element of \(\hat{\mathbf{y}}_t\).

At each time point, \(\mathbf{s}_{t}\) is a function of \(\mathbf{s}_{t-1}\), \(\mathbf{s}_{t-1}\) is in turn a function of \(\mathbf{s}_{t-2}\), and so on, so the effects of all previous input words are taken into account when predicting the word at \(t+1\); this is the recurrent part of the RNN.
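
A minimal numpy sketch of this forward pass may help; the toy dimensions, the random parameters, and the choice of tanh for \(f\) are illustrative assumptions, not part of the model definition.

```python
import numpy as np

rng = np.random.default_rng(0)
d, D_s, Voc, T = 8, 16, 100, 5               # toy dimensions, chosen arbitrarily

W = rng.normal(scale=0.1, size=(D_s, D_s))   # hidden-to-hidden weights
U = rng.normal(scale=0.1, size=(D_s, d))     # input-to-hidden weights
V = rng.normal(scale=0.1, size=(Voc, D_s))   # hidden-to-output weights

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

x = rng.normal(size=(T, d))                  # word vectors x_1, ..., x_T
s = np.zeros((T + 1, D_s))                   # s[0] is the initial hidden state
y_hat = np.zeros((T, Voc))

for t in range(1, T + 1):
    s[t] = np.tanh(W @ s[t - 1] + U @ x[t - 1])   # s_t = f(W s_{t-1} + U x_t)
    y_hat[t - 1] = softmax(V @ s[t])              # \hat{y}_t = softmax(V s_t)
# y_hat[t-1, j] approximates P(w_{t+1} = v_j | w_t, ..., w_1)
```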

The Cost Function of RNN

The loss function at time \(t\) for the RNN model above is

\[E_{t}(\hat{y}_{t}, y_{t}) = -\sum_{j=1}^{Voc}y_{t,j}log(\hat{y}_{t,j}),\]

\[E(\hat{y}, y)= \frac{1}{T}\sum_{t=1}^{T}E_{t}(\hat{y}_{t}, y_{t}),\]

where \(y_{t,j}\) is the \(j\)th element of the one-hot vector \(\mathbf{y}_t\) encoding the correct word at time \(t\); the loss function \(E\) tries to minimize the difference between \(y_{t,j}\) and \(\hat{y}_{t,j}\) over all words \(j\) and all times \(t\).
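
A small numpy sketch of this loss, representing the one-hot targets by the index of the correct word; the toy \(\hat{y}\) values and target indices are made up for illustration.

```python
import numpy as np

# Toy predicted distributions \hat{y}_t (rows sum to 1) and correct-word indices.
T, Voc = 3, 5
y_hat = np.array([[0.10, 0.60, 0.10, 0.10, 0.10],
                  [0.20, 0.20, 0.40, 0.10, 0.10],
                  [0.05, 0.05, 0.10, 0.70, 0.10]])
targets = np.array([1, 2, 3])                    # index j of the correct word at each t

E_t = -np.log(y_hat[np.arange(T), targets])      # E_t = -sum_j y_{t,j} log(y_hat_{t,j})
E = E_t.mean()                                   # E = (1/T) sum_t E_t
print(E_t, E)
```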

Back Propagation of RNN

Our goal is to compute the gradient of the total cost \(E(\hat{y}, y)\) (the average of the per-time-step losses \(E_t\)) wrt \(\mathbf{W}, \mathbf{U}\) and \(\mathbf{V}\). From the figure above, the gradient wrt \(\mathbf{V}\) only involves the values at the current time \(t\), while the gradients wrt \(\mathbf{W}\) and \(\mathbf{U}\) involve the values at the current time \(t\) and at all previous times.

Notice \(E(\hat{y}, y)\) is the average of the \(E_{t}(\hat{y}_{t}, y_{t})\), each of which involves \(\mathbf{W}, \mathbf{U}\) and \(\mathbf{V}\), so (dropping the constant factor \(1/T\)) we have

\[\frac{\partial{E}}{\partial{\mathbf{W}}} = \sum_{t=1}^{T}\frac{\partial{E}_{t}}{\partial{\mathbf{W}}}, \quad \frac{\partial{E}}{\partial{\mathbf{U}}} = \sum_{t=1}^{T}\frac{\partial{E}_{t}}{\partial{\mathbf{U}}}, \quad \frac{\partial{E}}{\partial{\mathbf{V}}} = \sum_{t=1}^{T}\frac{\partial{E}_{t}}{\partial{\mathbf{V}}}.\] Define \(\mathbf{z}_t = \mathbf{Vs}_t\), \(\mathbf{z}_t \in \mathbb{R}^{Voc}\); recall \(\frac{\partial{E}_t}{\partial{\mathbf{z}_t}} = \hat{\mathbf{y}}_t-\mathbf{y}_t = \mathbf{\delta}_{1t}\), then

\[\frac{\partial{E}_{t}}{\partial{\mathbf{V}}} = \mathbf{\delta}_{1t} \mathbf{s}_t', \quad \mathbf{\delta}_{1t} \in \mathbb{R}^{Voc}, \; \mathbf{s}_t \in \mathbb{R}^{D_s}, \; \mathbf{V} \in \mathbb{R}^{Voc\times D_s}.\] The gradient wrt \(\mathbf{V}\) at time \(t\) depends only on the values at the current time point, \(\hat{\mathbf{y}}_t, \mathbf{y}_t, \mathbf{s}_t\).
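
As a sketch, the gradient wrt \(\mathbf{V}\) at a single timestep is just an outer product; the toy vectors below are made up for illustration.

```python
import numpy as np

# Toy quantities at a single timestep t, to show dE_t/dV = delta_{1t} s_t'.
Voc, D_s = 5, 4
rng = np.random.default_rng(1)
s_t = rng.normal(size=D_s)                       # hidden state at time t
y_hat_t = np.array([0.1, 0.2, 0.4, 0.2, 0.1])    # softmax output at time t
y_t = np.array([0., 0., 1., 0., 0.])             # one-hot correct word

delta_1t = y_hat_t - y_t                         # dE_t/dz_t with z_t = V s_t
dEt_dV = np.outer(delta_1t, s_t)                 # shape (Voc, D_s), same as V
```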

Assuming \(f\) is the tanh function and writing \(\mathbf{h}_t = \mathbf{W}\mathbf{s}_{t-1} +\mathbf{U}\mathbf{x}_t\), \[\begin{array}{rcl}\frac{\partial{E}_{t}}{\partial{\mathbf{W}}} &=& \frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{h}_t}}\frac{\partial{\mathbf{h}}_t}{\partial{\mathbf{W}}}\\ &=& \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} f'(\mathbf{h}_t)\left[\mathbf{s}_{t-1}+\mathbf{W}\frac{\partial{\mathbf{s}_{t-1}}}{\partial{\mathbf{W}}}\right]\\ &=& \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \left[f'(\mathbf{h}_t)\,\mathbf{s}_{t-1}+f'(\mathbf{h}_t)\,\mathbf{W} \frac{\partial{\mathbf{s}_{t-1}}}{\partial{\mathbf{W}}}\right]\\ &=& \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{W}}} + \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{s}_{t-1}}}\frac{\partial{\mathbf{s}_{t-1}}}{\partial{\mathbf{W}}},\end{array}\] where \(\frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{W}}}\) in the last line denotes the immediate (explicit) derivative of \(\mathbf{s}_t\) wrt \(\mathbf{W}\), treating \(\mathbf{s}_{t-1}\) as a constant.

Since \(\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{W}}}\) decomposes in the same way, \(\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{W}}} = \frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{W}}} + \frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{s}}_{t-2}}\frac{\partial{\mathbf{s}}_{t-2}}{\partial{\mathbf{W}}}\) (with the first term again the immediate derivative), plugging it into the equation above, we have

\[\begin{array}{rcl} \frac{\partial{E}_{t}}{\partial{\mathbf{W}}} &=& \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{W}}} + \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{s}_{t-1}}}\frac{\partial{\mathbf{s}_{t-1}}}{\partial{\mathbf{W}}} + \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{s}_{t-2}}}\frac{\partial{\mathbf{s}_{t-2}}}{\partial{\mathbf{W}}} + \dots + \frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{s}_{1}}}\frac{\partial{\mathbf{s}_{1}}}{\partial{\mathbf{W}}}\\ &=& \sum_{k=1}^{t}\frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\frac{\partial{\mathbf{s}_t}}{\partial{\mathbf{s}_{k}}}\frac{\partial{\mathbf{s}_{k}}}{\partial{\mathbf{W}}},\end{array} \quad \mathbf{W} \in \mathbb{R}^{D_s\times D_s},\] where each \(\frac{\partial{\mathbf{s}_{k}}}{\partial{\mathbf{W}}}\) is an immediate derivative; summing over \(t\) then gives \(\frac{\partial{E}}{\partial{\mathbf{W}}}\).
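
A sketch of this back-propagation-through-time sum for a single \(t\), assuming the indexing convention of the forward-pass sketch above (\(f=\) tanh, the initial state stored in s[0], \(\mathbf{\delta}_{1t} = \hat{\mathbf{y}}_t - \mathbf{y}_t\)); the function name dEt_dW is hypothetical.

```python
import numpy as np

# BPTT sketch for dE_t/dW at a single timestep t (tanh RNN).
def dEt_dW(t, s, W, V, delta_1t):
    dW = np.zeros_like(W)
    ds_k = V.T @ delta_1t                  # dE_t/ds_t
    for k in range(t, 0, -1):              # walk back k = t, t-1, ..., 1
        dh_k = ds_k * (1.0 - s[k] ** 2)    # through f = tanh: f'(h_k) = 1 - s_k^2
        dW += np.outer(dh_k, s[k - 1])     # "immediate" term ds_k/dW
        ds_k = W.T @ dh_k                  # push dE_t/ds_k back to s_{k-1}
    return dW
```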

Similarly, the gradient of \(E_t\) wrt \(\mathbf{U}\) is

\[ \frac{\partial{E}_{t}}{\partial{\mathbf{U}}} = \frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{U}}} \] In \(\mathbf{s}_t\), \(\mathbf{U}\) enters both through \(\mathbf{s}_{t-1}\) and through the term \(\mathbf{Ux}_t\), so

\[ \frac{\partial{E}_{t}}{\partial{\mathbf{U}}} = \frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\left(\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{U}}} + \frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}_{t-1}}} \frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{U}}}\right), \] where the first term inside the parentheses is the immediate derivative. Since \(\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{U}}} = \left(\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{U}}} + \frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{s}_{t-2}}} \frac{\partial{\mathbf{s}}_{t-2}}{\partial{\mathbf{U}}}\right)\), plugging it into the above equation, we have

\[ \frac{\partial{E}_{t}}{\partial{\mathbf{\mathbf{U}}}} = \frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\left(\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{U}}} + \frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}_{t-1}}} \left(\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{U}}} + \frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{s}_{t-2}}} \frac{\partial{\mathbf{s}}_{t-2}}{\partial{\mathbf{U}}}\right)\right) \]

Continuing this substitution back to the initial time point, we finally have

\[ \frac{\partial{E}_{t}}{\partial{\mathbf{U}}} = \sum_{k=1}^{t}\frac{\partial{E}_t}{\partial{\mathbf{s}_t}} \frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}_{k}}}\frac{\partial{\mathbf{s}}_{k}}{\partial{\mathbf{U}}}. \] Notice that the gradients wrt \(\mathbf{W}\) and \(\mathbf{U}\) at time \(t\) depend on both the current and all previous values.
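
Putting the three gradients together, here is a sketch of full BPTT over all timesteps, again assuming the variables x, s, y_hat from the forward-pass sketch above and a hypothetical targets array of correct word indices.

```python
import numpy as np

# Full BPTT sketch: gradients of E = (1/T) sum_t E_t wrt V, W, U for a tanh RNN,
# given the forward-pass quantities x (inputs), s (hidden states, s[0] initial),
# y_hat (softmax outputs) and targets (correct word index at each t).
def bptt(x, s, y_hat, targets, W, U, V):
    T = x.shape[0]
    dV, dW, dU = np.zeros_like(V), np.zeros_like(W), np.zeros_like(U)
    for t in range(T, 0, -1):
        delta_1t = y_hat[t - 1].copy()
        delta_1t[targets[t - 1]] -= 1.0        # dE_t/dz_t = y_hat_t - y_t
        dV += np.outer(delta_1t, s[t])         # dE_t/dV = delta_{1t} s_t'
        ds_k = V.T @ delta_1t                  # dE_t/ds_t
        for k in range(t, 0, -1):              # sum over k = t, ..., 1
            dh_k = ds_k * (1.0 - s[k] ** 2)    # tanh derivative
            dW += np.outer(dh_k, s[k - 1])     # immediate ds_k/dW term
            dU += np.outer(dh_k, x[k - 1])     # immediate ds_k/dU term
            ds_k = W.T @ dh_k                  # propagate to s_{k-1}
    return dV / T, dW / T, dU / T              # E averages the E_t over T
```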

PRACTICE: Based on a sketch of the network at a single timestep:

Draw the ‘unrolled’ network for the previous 3 timesteps (\(t-1\), \(t-2\), \(t-3\)) and compute \(\frac{\partial{E}_{t}}{\partial{\mathbf{U}}}\bigg\rvert_{t-1}\), \(\frac{\partial{E}_{t}}{\partial{\mathbf{W}}}\bigg\rvert_{t-1}\), \(\frac{\partial{E}_{t}}{\partial{\mathbf{V}}}\), and \(\frac{\partial{E}_{t}}{\partial{\mathbf{x}_t}}\). The complexity of back propagation wrt the model parameters across the entire \(T\) timesteps, \(\sum_{t=1}^T E_t\), is \(O\left[T(VD_h + dD_h + t(D_h^2 + dD_h))\right]\); the costly term is \(VD_h\), which involves the size of the whole vocabulary, since \(V \gg D_h\).
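
A back-of-the-envelope check of which term dominates, using made-up sizes for \(V\), \(D_h\) and \(d\):

```python
# Rough per-timestep multiply-add counts, with illustrative (made-up) sizes.
V, D_h, d = 50_000, 256, 128      # vocabulary size, hidden size, word-vector size

print(V * D_h)    # softmax projection V s_t:    ~12.8M multiply-adds
print(D_h * D_h)  # recurrent product W s_{t-1}: ~65K multiply-adds
print(d * D_h)    # input product U x_t:         ~33K multiply-adds
# Since V >> D_h, the V * D_h softmax term dominates the cost.
```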

The Vanishing Gradient Problem

If we look deeper at the \(\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}}_{k}}\) part in \(\frac{\partial{E}_{t}}{\partial{\mathbf{W}}} = \sum_{k=1}^{t}\frac{\partial{E}_t}{\partial{\mathbf{s}_t}}\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}}_{k}}\frac{\partial{\mathbf{s}}_{k}}{\partial{\mathbf{W}}}\), it is itself a product of Jacobians,

\[\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}}_{k}} = \frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}}_{t-1}}\frac{\partial{\mathbf{s}}_{t-1}}{\partial{\mathbf{s}}_{t-2}}\cdots\frac{\partial{\mathbf{s}}_{k+1}}{\partial{\mathbf{s}}_{k}} = \prod_{j=k+1}^{t} \mathrm{diag}\left[f'(\mathbf{h}_j)\right]\mathbf{W}.\]

Since \(|tanh'| \le 1\), the norm of each factor is bounded by (a constant times) the norm of \(\mathbf{W}\); when these norms are smaller than 1, \(\frac{\partial{\mathbf{s}}_t}{\partial{\mathbf{s}}_{k}}\) shrinks exponentially as \(t-k\) grows, so the gradient contributions from distant time steps vanish (and when the norms are larger than 1, they explode). This is the vanishing (or exploding) gradient problem of the plain RNN: long-range dependencies contribute almost nothing to the parameter updates.

LSTM

The Long Short-Term Memory (LSTM) network addresses this problem by replacing the simple recurrence with a gated cell. At each time step the cell takes the input \(\mathbf{x}_t\) and the previous output \(\mathbf{out}_{t-1}\), and maintains an internal state \(\mathbf{state}_t\) through the following components:

  • activation: the candidate new information computed from the current input and the previous output,
  • \[\mathbf{a}_t = tanh(\mathbf{W}_a\mathbf{x}_t + \mathbf{U}_a \mathbf{out}_{t-1} + \mathbf{b}_a)\]
  • input gate: controlling how much of \(\mathbf{a}_t\) enters \(\mathbf{state}_t\),
  • \[\mathbf{i}_t = \sigma(\mathbf{W}_i\mathbf{x}_t + \mathbf{U}_i \mathbf{out}_{t-1} + \mathbf{b}_i)\]

  • forget gate: controlling how much \(\mathbf{state}_{t-1}\) influences \(\mathbf{state}_t\) (how much the past state should matter now); if \(\mathbf{f}_t\) is 1, \(\mathbf{state}_{t-1}\) is directly copied into time \(t\), and if \(\mathbf{f}_t\) is 0, \(\mathbf{state}_{t-1}\) is completely ignored at time \(t\),
  • \[\mathbf{f}_t = \sigma(\mathbf{W}_f\mathbf{x}_t + \mathbf{U}_f \mathbf{out}_{t-1} + \mathbf{b}_f)\]

  • output gate: controlling how much \(\mathbf{state}_t\) is exposed in \(\mathbf{out}_t\),
  • \[\mathbf{o}_t = \sigma(\mathbf{W}_o\mathbf{x}_t + \mathbf{U}_o \mathbf{out}_{t-1} + \mathbf{b}_o)\]

These gates lead to:

  • internal state:
  • \[\mathbf{state}_t = \mathbf{a}_t \odot \mathbf{i}_t + \mathbf{f}_t\odot \mathbf{state}_{t-1}\]

  • output at time \(t\):
  • \[\mathbf{out}_t = tanh(\mathbf{state}_{t})\odot \mathbf{o}_t\]

Due to the elementwise products \(\odot\) in \(\mathbf{f}_t\odot \mathbf{state}_{t-1}\), \(\mathbf{a}_t \odot \mathbf{i}_t\), and \(tanh(\mathbf{state}_{t})\odot \mathbf{o}_t\), when back propagating wrt the parameters (\(\mathbf{W}, \mathbf{U}, \mathbf{b}, \mathbf{x}\) in the activation, forget, input and output gates), the quantities \(\mathbf{state}_{t-1}, \mathbf{f}_t, \mathbf{a}_t, \mathbf{i}_t, \mathbf{o}_t, \mathbf{state}_{t}\) are directly copied (completely or partially, depending on the gates) to the current time step rather than being pushed through a matrix multiplication at every step; in particular, the gradient of \(\mathbf{state}_t\) wrt \(\mathbf{state}_{t-1}\) is just the gate \(\mathbf{f}_t\), so the vanishing gradient problem of the plain RNN is avoided.
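
A minimal numpy sketch of one LSTM step in this notation; the parameter shapes, the random initialization, and the function name lstm_step are illustrative assumptions, not a reference implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One LSTM step: gates a_t, i_t, f_t, o_t, then the elementwise state update.
def lstm_step(x_t, out_prev, state_prev, params):
    (W_a, U_a, b_a), (W_i, U_i, b_i), (W_f, U_f, b_f), (W_o, U_o, b_o) = params
    a_t = np.tanh(W_a @ x_t + U_a @ out_prev + b_a)     # activation (candidate)
    i_t = sigmoid(W_i @ x_t + U_i @ out_prev + b_i)     # input gate
    f_t = sigmoid(W_f @ x_t + U_f @ out_prev + b_f)     # forget gate
    o_t = sigmoid(W_o @ x_t + U_o @ out_prev + b_o)     # output gate
    state_t = a_t * i_t + f_t * state_prev              # elementwise, no matrix product
    out_t = np.tanh(state_t) * o_t
    return out_t, state_t

# Toy usage with arbitrary dimensions and random parameters.
rng = np.random.default_rng(2)
d, D = 4, 3
params = tuple((rng.normal(size=(D, d)), rng.normal(size=(D, D)), np.zeros(D))
               for _ in range(4))
out, state = lstm_step(rng.normal(size=d), np.zeros(D), np.zeros(D), params)
```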

The Back Propagation of LSTM

Let \(\mathbf{z}_a, \mathbf{z}_f, \mathbf{z}_i, \mathbf{z}_o\) denote the pre-activations \(\mathbf{W}\mathbf{x}_t + \mathbf{U} \mathbf{out}_{t-1} + \mathbf{b}\) inside \(\mathbf{a}_t, \mathbf{f}_t, \mathbf{i}_t, \mathbf{o}_t\), respectively, and let \(\frac{\partial{E}}{\partial{\mathbf{out}_t}}=(\mathbf{out}_t-\mathbf{y}_t) + (\mathbf{out}_{t+1}-\mathbf{y}_{t+1})\frac{\partial{\mathbf{out}_{t+1}}}{\partial{\mathbf{out}_t}} = \mathbf{\delta}_t + \mathbf{\delta}_{t+1}\frac{\partial{\mathbf{out}_{t+1}}}{\partial{\mathbf{out}_t}}\) (the first term corresponds to the red solid line in the flowchart; the second term, the gradient from future time steps, corresponds to the red dotted line). Then, following the blue line in the flowchart,

\[\frac{\partial{E}}{\partial{\mathbf{s}_t}} =\frac{\partial{E}}{\partial{\mathbf{out}_t}}\odot \mathbf{o}_t\odot[1-tanh^2(\mathbf{s}_t)] + \frac{\partial{E}}{\partial{\mathbf{s}_{t+1}}}\odot\mathbf{f}_{t+1},\] where \(\mathbf{s}_t\) here denotes the internal state \(\mathbf{state}_t\); the first term corresponds to the blue solid back propagation line, and the second term comes from the blue dotted line.

The gradient of \(\mathbf{s}_t\) involves the gradient of the future state \(\mathbf{s}_{t+1}\); when \(t\) is the last time step, the second term in \(\frac{\partial{E}}{\partial{\mathbf{s}_t}}\) is \(\mathbf{0}\).

Next,

\[\frac{\partial{E}}{\partial{\mathbf{z}_{a_t}}} = \frac{\partial{E}}{\partial{\mathbf{s}_t}}\odot\mathbf{i}_t \odot[1-tanh^2(\mathbf{z}_{a_t})],\] \[\frac{\partial{E}}{\partial{\mathbf{z}_{o_t}}}=\frac{\partial{E}}{\partial{\mathbf{out}_t}}\odot tanh(\mathbf{s}_t)\odot \sigma_o\odot(1-\sigma_o),\] \[\frac{\partial{E}}{\partial{\mathbf{z}_{f_t}}} = \frac{\partial{E}}{\partial{\mathbf{s}_t}}\odot\mathbf{s}_{t-1}\odot \sigma_f\odot(1-\sigma_f),\] \[\frac{\partial{E}}{\partial{\mathbf{z}_{i_t}}} = \frac{\partial{E}}{\partial{\mathbf{s}_t}}\odot\mathbf{a}_t \odot \sigma_i\odot(1-\sigma_i),\] where \(\sigma_o, \sigma_f, \sigma_i\) are shorthand for the sigmoid outputs \(\mathbf{o}_t, \mathbf{f}_t, \mathbf{i}_t\).

Notice that \(\frac{\partial{E}}{\partial{\mathbf{out}_t}}\) involves \(\mathbf{\delta}\) at time \(t\) as well as the gradient from future time steps, and \(\frac{\partial{E}}{\partial{\mathbf{s}_t}}\) likewise involves the gradient wrt \(\mathbf{s}\) at time \(t\) and at future time steps, because, as the flowchart of the LSTM shows, each gate at time \(t\) is involved in all the future time steps.
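
A sketch of these backward equations for a single timestep, assuming the quantities come from a forward step like the lstm_step sketch above and that dE_dout_t and dE_dstate_next already include the gradient flowing back from future timesteps; the gradients wrt the gate parameters then follow by combining each \(\frac{\partial{E}}{\partial{\mathbf{z}}}\) with \(\mathbf{x}_t\) (for \(\mathbf{W}\)), \(\mathbf{out}_{t-1}\) (for \(\mathbf{U}\)) or 1 (for \(\mathbf{b}\)).

```python
import numpy as np

# Backward-pass sketch for one LSTM timestep, following the equations above.
def lstm_step_backward(dE_dout_t, dE_dstate_next, f_next,
                       a_t, i_t, f_t, o_t, state_t, state_prev):
    # dE/ds_t: path through out_t plus the future-state path gated by f_{t+1}
    dE_dstate_t = (dE_dout_t * o_t * (1.0 - np.tanh(state_t) ** 2)
                   + dE_dstate_next * f_next)
    dE_dz_a = dE_dstate_t * i_t * (1.0 - a_t ** 2)          # tanh'(z_a) = 1 - a_t^2
    dE_dz_o = dE_dout_t * np.tanh(state_t) * o_t * (1.0 - o_t)
    dE_dz_f = dE_dstate_t * state_prev * f_t * (1.0 - f_t)
    dE_dz_i = dE_dstate_t * a_t * i_t * (1.0 - i_t)
    return dE_dstate_t, dE_dz_a, dE_dz_i, dE_dz_f, dE_dz_o
```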