
Commit c6d079e

update with bpt
1 parent 93cc6f7 commit c6d079e


2 files changed

+90
-0
lines changed


doc/src/week7/Latexfiles/.DS_Store

6 KB
Binary file not shown.

doc/src/week7/Latexfiles/bpt.tex

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
\documentclass{article}
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{physics}

\begin{document}

\section*{Backpropagation Through Time (BPTT) for Recurrent Neural Networks}

Backpropagation Through Time (BPTT) extends the backpropagation algorithm to the training of Recurrent Neural Networks (RNNs). Unlike feedforward neural networks, RNNs have recurrent connections that feed the hidden state back into the network, allowing them to maintain a ``memory'' of previous inputs. BPTT is therefore more involved than standard backpropagation: gradients must be propagated not only through layers but also backward through time steps, all of which share the same parameters.

\subsection*{RNN Structure}

Consider an RNN with the following structure:
\begin{itemize}
\item Input at time step \( t \): \( \mathbf{x}_t \)
\item Hidden state at time step \( t \): \( \mathbf{h}_t \)
\item Output at time step \( t \): \( \mathbf{y}_t \)
\item Weight matrices: \( \mathbf{W}_h \) (hidden-to-hidden), \( \mathbf{W}_x \) (input-to-hidden), \( \mathbf{W}_y \) (hidden-to-output)
\item Bias vectors: \( \mathbf{b}_h \) (hidden), \( \mathbf{b}_y \) (output)
\end{itemize}

The hidden state and output are computed as:
\[
\mathbf{h}_t = \sigma(\mathbf{W}_h \mathbf{h}_{t-1} + \mathbf{W}_x \mathbf{x}_t + \mathbf{b}_h),
\]
\[
\mathbf{y}_t = \mathbf{W}_y \mathbf{h}_t + \mathbf{b}_y,
\]
where \( \sigma(\cdot) \) is the activation function (e.g., \( \tanh \) or ReLU), applied element-wise.
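
As an illustration of how the recurrence unrolls, the first two steps can be written out explicitly. This is only a sketch: the initial state \( \mathbf{h}_0 \) is not specified above, and taking \( \mathbf{h}_0 = \mathbf{0} \) is an assumption made here for concreteness.
\begin{align*}
\mathbf{h}_1 &= \sigma(\mathbf{W}_h \mathbf{h}_0 + \mathbf{W}_x \mathbf{x}_1 + \mathbf{b}_h) = \sigma(\mathbf{W}_x \mathbf{x}_1 + \mathbf{b}_h), \\
\mathbf{h}_2 &= \sigma\bigl(\mathbf{W}_h\, \sigma(\mathbf{W}_x \mathbf{x}_1 + \mathbf{b}_h) + \mathbf{W}_x \mathbf{x}_2 + \mathbf{b}_h\bigr).
\end{align*}
The same \( \mathbf{W}_h \), \( \mathbf{W}_x \), and \( \mathbf{b}_h \) appear at every step, so each parameter influences the loss through every time step at which it is used.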

\subsection*{Loss Function}

The loss function \( L \) measures the difference between the predicted output \( \mathbf{y}_t \) and the true output \( \mathbf{\hat{y}}_t \), summed over all time steps \( t = 1, \dots, T \):
\[
L = \sum_{t=1}^T L_t(\mathbf{y}_t, \mathbf{\hat{y}}_t),
\]
where \( L_t \) is the loss at time step \( t \) (e.g., mean squared error or cross-entropy).
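
For example, with a squared-error loss (one possible choice, used here purely as an illustration; the factor \( \tfrac{1}{2} \) is a convention assumed for convenience),
\[
L_t = \tfrac{1}{2} \lVert \mathbf{y}_t - \mathbf{\hat{y}}_t \rVert^2,
\qquad
\frac{\partial L_t}{\partial \mathbf{y}_t} = \mathbf{y}_t - \mathbf{\hat{y}}_t,
\]
so the output error at each step is simply the difference between the prediction \( \mathbf{y}_t \) and the target \( \mathbf{\hat{y}}_t \).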

\subsection*{Backpropagation Through Time (BPTT)}

The goal of BPTT is to compute the gradients of the loss \( L \) with respect to the parameters \( \mathbf{W}_h \), \( \mathbf{W}_x \), \( \mathbf{W}_y \), \( \mathbf{b}_h \), and \( \mathbf{b}_y \). These gradients are used to update the parameters via gradient descent.

\subsubsection*{Gradient of the Loss with Respect to \( \mathbf{W}_y \) and \( \mathbf{b}_y \)}

The gradients of \( L \) with respect to \( \mathbf{W}_y \) and \( \mathbf{b}_y \) are straightforward since \( \mathbf{y}_t \) depends directly on these parameters:
\[
\frac{\partial L}{\partial \mathbf{W}_y} = \sum_{t=1}^T \frac{\partial L_t}{\partial \mathbf{y}_t} \frac{\partial \mathbf{y}_t}{\partial \mathbf{W}_y},
\]
\[
\frac{\partial L}{\partial \mathbf{b}_y} = \sum_{t=1}^T \frac{\partial L_t}{\partial \mathbf{y}_t} \frac{\partial \mathbf{y}_t}{\partial \mathbf{b}_y}.
\]
Because \( \mathbf{y}_t = \mathbf{W}_y \mathbf{h}_t + \mathbf{b}_y \) is linear in both parameters, these expressions reduce to \( \frac{\partial L}{\partial \mathbf{W}_y} = \sum_{t=1}^T \frac{\partial L_t}{\partial \mathbf{y}_t} \mathbf{h}_t^\top \) and \( \frac{\partial L}{\partial \mathbf{b}_y} = \sum_{t=1}^T \frac{\partial L_t}{\partial \mathbf{y}_t} \); that is, \( \frac{\partial \mathbf{y}_t}{\partial \mathbf{W}_y} \) contributes the factor \( \mathbf{h}_t^\top \) and \( \frac{\partial \mathbf{y}_t}{\partial \mathbf{b}_y} = \mathbf{I} \).
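
\subsubsection*{Gradient of the Loss with Respect to \( \mathbf{W}_h \), \( \mathbf{W}_x \), and \( \mathbf{b}_h \)}

The gradients with respect to \( \mathbf{W}_h \), \( \mathbf{W}_x \), and \( \mathbf{b}_h \) are more complex because the hidden state \( \mathbf{h}_t \) depends on previous hidden states, and hence on the same parameters at every earlier time step. We use the chain rule to propagate the error backward through time.

Let \( \mathbf{z}_t = \mathbf{W}_h \mathbf{h}_{t-1} + \mathbf{W}_x \mathbf{x}_t + \mathbf{b}_h \) denote the pre-activation, so that \( \mathbf{h}_t = \sigma(\mathbf{z}_t) \), and let \( \mathbf{\delta}_t = \frac{\partial L}{\partial \mathbf{h}_t} \) be the error at time step \( t \). The error at time step \( t \) depends on the current output error and on the error at time step \( t+1 \):
\[
\mathbf{\delta}_t = \frac{\partial L_t}{\partial \mathbf{h}_t} + \mathbf{W}_h^\top (\mathbf{\delta}_{t+1} \odot \sigma'(\mathbf{z}_{t+1})),
\]
where \( \odot \) denotes element-wise multiplication, \( \sigma'(\cdot) \) is the derivative of the activation function evaluated at the pre-activation, and \( \frac{\partial L_t}{\partial \mathbf{h}_t} = \mathbf{W}_y^\top \frac{\partial L_t}{\partial \mathbf{y}_t} \). The recursion starts at the last time step with \( \mathbf{\delta}_T = \frac{\partial L_T}{\partial \mathbf{h}_T} \), since no later step depends on \( \mathbf{h}_T \).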
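
The recursion follows from the chain rule, since \( \mathbf{h}_t \) affects the loss both directly through \( \mathbf{y}_t \) and indirectly through \( \mathbf{h}_{t+1} \). As a sketch, with \( \mathbf{z}_{t+1} = \mathbf{W}_h \mathbf{h}_t + \mathbf{W}_x \mathbf{x}_{t+1} + \mathbf{b}_h \) denoting the pre-activation, so that \( \mathbf{h}_{t+1} = \sigma(\mathbf{z}_{t+1}) \):
\begin{align*}
\frac{\partial L}{\partial \mathbf{h}_t}
&= \frac{\partial L_t}{\partial \mathbf{h}_t} + \left( \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right)^{\top} \frac{\partial L}{\partial \mathbf{h}_{t+1}}, \\
\frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t}
&= \operatorname{diag}\bigl(\sigma'(\mathbf{z}_{t+1})\bigr)\, \mathbf{W}_h.
\end{align*}
Multiplying by the transposed Jacobian \( \mathbf{W}_h^\top \operatorname{diag}(\sigma'(\mathbf{z}_{t+1})) \) is the same as scaling \( \frac{\partial L}{\partial \mathbf{h}_{t+1}} \) element-wise by \( \sigma'(\mathbf{z}_{t+1}) \) and then applying \( \mathbf{W}_h^\top \).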
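
The gradients with respect to \( \mathbf{W}_h \), \( \mathbf{W}_x \), and \( \mathbf{b}_h \) are then computed by first passing each error through the activation at its own time step:
\[
\frac{\partial L}{\partial \mathbf{W}_h} = \sum_{t=1}^T (\mathbf{\delta}_t \odot \sigma'(\mathbf{z}_t))\, \mathbf{h}_{t-1}^\top,
\]
\[
\frac{\partial L}{\partial \mathbf{W}_x} = \sum_{t=1}^T (\mathbf{\delta}_t \odot \sigma'(\mathbf{z}_t))\, \mathbf{x}_t^\top,
\]
\[
\frac{\partial L}{\partial \mathbf{b}_h} = \sum_{t=1}^T \mathbf{\delta}_t \odot \sigma'(\mathbf{z}_t),
\]
where \( \mathbf{z}_t = \mathbf{W}_h \mathbf{h}_{t-1} + \mathbf{W}_x \mathbf{x}_t + \mathbf{b}_h \) is the pre-activation at step \( t \) and \( \odot \) is element-wise multiplication. The factor \( \sigma'(\mathbf{z}_t) \) is needed because \( \mathbf{\delta}_t \) is the gradient with respect to \( \mathbf{h}_t = \sigma(\mathbf{z}_t) \), whereas the parameters enter through \( \mathbf{z}_t \).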

\subsubsection*{Summary of BPTT}

The BPTT algorithm can be summarized as follows:
\begin{enumerate}
\item Forward pass: Compute the hidden states \( \mathbf{h}_t \) and outputs \( \mathbf{y}_t \) for all time steps.
\item Backward pass: Compute the errors \( \mathbf{\delta}_t \) starting from the last time step and propagate them backward through time.
\item Compute the gradients with respect to all parameters using the errors \( \mathbf{\delta}_t \).
\item Update the parameters using gradient descent.
\end{enumerate}
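
For the last step, plain gradient descent updates each parameter \( \theta \in \{ \mathbf{W}_h, \mathbf{W}_x, \mathbf{W}_y, \mathbf{b}_h, \mathbf{b}_y \} \) as
\[
\theta \leftarrow \theta - \eta \, \frac{\partial L}{\partial \theta},
\]
where the learning rate \( \eta > 0 \) is a hyperparameter introduced here for illustration; it is not specified above, and other optimizers could be substituted.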

\subsection*{Challenges of BPTT}

BPTT can suffer from vanishing or exploding gradients, especially on long sequences, because the error is multiplied by \( \mathbf{W}_h^\top \) and by the activation derivative once per time step on its way backward. Gradient clipping is commonly used against exploding gradients, while gated architectures such as Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs) are often employed to mitigate vanishing gradients.
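
To see where the problem comes from, one can unroll the dependence of a late hidden state on an early one. The following is a sketch, with \( \mathbf{z}_k = \mathbf{W}_h \mathbf{h}_{k-1} + \mathbf{W}_x \mathbf{x}_k + \mathbf{b}_h \) denoting the pre-activation, \( \lVert \cdot \rVert \) the spectral norm, and assuming \( \lvert \sigma' \rvert \le \gamma \) for some constant \( \gamma \) (for example \( \gamma = 1 \) for \( \tanh \)):
\[
\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_t} = \prod_{k=t+1}^{T} \operatorname{diag}\bigl(\sigma'(\mathbf{z}_k)\bigr)\, \mathbf{W}_h,
\qquad
\left\lVert \frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_t} \right\rVert \le \bigl( \gamma \lVert \mathbf{W}_h \rVert \bigr)^{T-t},
\]
where the factors in the product are ordered with \( k = T \) leftmost. If \( \gamma \lVert \mathbf{W}_h \rVert < 1 \), the bound shrinks exponentially in \( T - t \) and the gradient vanishes; if the factors have norm larger than one, the product can instead grow exponentially.

Gradient clipping, in its common norm-based form, rescales a gradient \( \mathbf{g} \) whenever its norm exceeds a threshold \( c > 0 \) (a hyperparameter assumed here for illustration):
\[
\mathbf{g} \leftarrow \mathbf{g} \cdot \min\!\left( 1, \frac{c}{\lVert \mathbf{g} \rVert} \right).
\]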

\end{document}
