We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ef3312d commit aa80482Copy full SHA for aa80482
ann_class2/momentum.py
@@ -62,11 +62,17 @@ def main():
62
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
63
# print "first batch cost:", cost(pYbatch, Ybatch)
64
65
+ # gradients
66
+ gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
67
+ gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
68
+ gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
69
+ gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
70
+
71
# updates
- W2 -= lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
- b2 -= lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
- W1 -= lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
- b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
72
+ W2 -= lr*gW2
73
+ b2 -= lr*gb2
74
+ W1 -= lr*gW1
75
+ b1 -= lr*gb1
76
77
if j % print_period == 0:
78
pY, _ = forward(Xtest, W1, b1, W2, b2)
ann_class2/rmsprop.py
@@ -48,11 +48,17 @@ def main():
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# calculate just for LL
0 commit comments