Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bob7783 committed Mar 4, 2020
1 parent ef3312d commit aa80482
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
14 changes: 10 additions & 4 deletions ann_class2/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ def main():
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
# print "first batch cost:", cost(pYbatch, Ybatch)

# gradients
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1

# updates
W2 -= lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
b2 -= lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
W1 -= lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
W2 -= lr*gW2
b2 -= lr*gb2
W1 -= lr*gW1
b1 -= lr*gb1

if j % print_period == 0:
pY, _ = forward(Xtest, W1, b1, W2, b2)
Expand Down
14 changes: 10 additions & 4 deletions ann_class2/rmsprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,17 @@ def main():
pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)
# print "first batch cost:", cost(pYbatch, Ybatch)

# gradients
gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1

# updates
W2 -= lr*(derivative_w2(Z, Ybatch, pYbatch) + reg*W2)
b2 -= lr*(derivative_b2(Ybatch, pYbatch) + reg*b2)
W1 -= lr*(derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1)
b1 -= lr*(derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1)
W2 -= lr*gW2
b2 -= lr*gb2
W1 -= lr*gW1
b1 -= lr*gb1

if j % print_period == 0:
# calculate just for LL
Expand Down

0 comments on commit aa80482

Please sign in to comment.