Skip to content

Commit aa80482

Browse files
committed
update
1 parent ef3312d commit aa80482

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

ann_class2/momentum.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,17 @@ def main():
6262
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
6363
# print "first batch cost:", cost(pYbatch, Ybatch)
6464

65+
# gradients
66+
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
67+
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
68+
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
69+
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
70+
6571
# updates
66-
W2 -= lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
67-
b2 -= lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
68-
W1 -= lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
69-
b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
72+
W2 -= lr*gW2
73+
b2 -= lr*gb2
74+
W1 -= lr*gW1
75+
b1 -= lr*gb1
7076

7177
if j % print_period == 0:
7278
pY, _ = forward(Xtest, W1, b1, W2, b2)

ann_class2/rmsprop.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,17 @@ def main():
4848
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
4949
# print "first batch cost:", cost(pYbatch, Ybatch)
5050

51+
# gradients
52+
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
53+
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
54+
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
55+
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
56+
5157
# updates
52-
W2 -= lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
53-
b2 -= lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
54-
W1 -= lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
55-
b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
58+
W2 -= lr*gW2
59+
b2 -= lr*gb2
60+
W1 -= lr*gW1
61+
b1 -= lr*gb1
5662

5763
if j % print_period == 0:
5864
# calculate just for LL

0 commit comments

Comments
 (0)