-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSolver.js
More file actions
49 lines (44 loc) · 1.29 KB
/
Solver.js
File metadata and controls
49 lines (44 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const { Matrix } = require('./Matrix')
/**
 * RMSProp-style optimizer with per-element gradient clipping and weight decay.
 *
 * Keeps one running mean of squared gradients per model parameter, stored in
 * `step_cache` under the same key the parameter has in the model object.
 */
class Solver {
  constructor () {
    // Decay factor for the running mean of squared gradients.
    this.decay_rate = 0.999
    // Added under the square root so the update denominator never reaches zero.
    this.smooth_eps = 1e-8
    // Parameter name -> Matrix holding that parameter's running squared-gradient mean.
    this.step_cache = {}
  }

  /**
   * Apply one update to every parameter in `model`, in place.
   *
   * Each own enumerable value of `model` is read as a matrix-like object with
   * `rows`, `columns`, and flat `w` (weights) and `dw` (gradients) arrays.
   * For every element the squared-gradient cache is updated with the raw
   * gradient, the gradient is then clamped to [-clipval, clipval], and the
   * weight moves by `-stepSize * g / sqrt(cache + smooth_eps) - regc * w`.
   * Gradients are reset to 0 afterwards.
   *
   * @param {Object} model - parameter name -> matrix-like object; `w` and `dw` are mutated.
   * @param {number} stepSize - learning rate.
   * @param {number} [regc=0] - weight-decay coefficient (each weight also loses `regc * w`).
   * @param {number} [clipval=Infinity] - clamp bound for each gradient element.
   * @returns {{ratio_clipped: number}} fraction of gradient elements that were clipped;
   *   0 when the model contains no elements.
   */
  step (model, stepSize, regc = 0, clipval = Infinity) {
    // Own-property test that also works on null-prototype objects and on
    // objects that shadow `hasOwnProperty`.
    const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key)
    let numClipped = 0
    let numTot = 0

    for (const k of Object.keys(model)) {
      const m = model[k]
      // A plain `k in this.step_cache` would match inherited names such as
      // "toString" and return a prototype function instead of a cache matrix.
      if (!hasOwn(this.step_cache, k)) {
        // assumes new Matrix(rows, columns) zero-initialises `w` — TODO confirm
        this.step_cache[k] = new Matrix(m.rows, m.columns)
      }
      const s = this.step_cache[k]

      for (let i = 0, n = m.w.length; i < n; i++) {
        let g = m.dw[i]
        // rmsprop adaptive learning rate: the cache sees the unclipped gradient
        s.w[i] = s.w[i] * this.decay_rate + (1.0 - this.decay_rate) * g * g
        // gradient clip
        if (g > clipval) {
          g = clipval
          numClipped++
        }
        if (g < -clipval) {
          g = -clipval
          numClipped++
        }
        numTot++
        // update (and regularize); the right-hand side reads the pre-update weight
        m.w[i] += -stepSize * g / Math.sqrt(s.w[i] + this.smooth_eps) - regc * m.w[i]
        m.dw[i] = 0 // reset gradients for next iteration
      }
    }

    // Guard 0/0 so an empty model reports 0 instead of NaN.
    return { ratio_clipped: numTot === 0 ? 0 : numClipped / numTot }
  }
}
module.exports = Solver