Commit cd5e114

optimize allocs

1 parent: bad5a03

5 files changed: +40 -22 lines


calc/ndarray.go (+18 -1)

@@ -371,7 +371,11 @@ func (a NDArray) Slice(axis int, start int, end int) NDArray {
 	outShape[axis] = end - start
 
 	arr := Zeros(outShape...)
-	walkSlice(a.shape, outShape, axis, start, func(inIndex int, outIndex int) {
+	return a.SliceInto(axis, start, end, arr)
+}
+
+func (a NDArray) SliceInto(axis int, start int, end int, arr NDArray) NDArray {
+	walkSlice(a.shape, arr.shape, axis, start, func(inIndex int, outIndex int) {
 		arr.data[outIndex] = a.data[inIndex]
 	})
 	return arr
@@ -741,7 +745,10 @@ func (a NDArray) InverseConv2DInto(g NDArray, hAxis int, wAxis int, fAxis int, a
 
 func (a NDArray) ReLU() NDArray {
 	arr := Zeros(a.shape...)
+	return a.ReLUInto(arr)
+}
 
+func (a NDArray) ReLUInto(arr NDArray) NDArray {
 	for i, v := range a.data {
 		if v > 0. {
 			arr.data[i] = v
@@ -753,7 +760,10 @@ func (a NDArray) ReLU() NDArray {
 
 func (a NDArray) ReLUMask(m NDArray) NDArray {
 	arr := Zeros(BroadcastShape(a.shape, m.shape)...)
+	return a.ReLUMaskInto(m, arr)
+}
 
+func (a NDArray) ReLUMaskInto(m NDArray, arr NDArray) NDArray {
 	if ShapeEqual(a.shape, m.shape) {
 		for i := range arr.data {
 			if m.data[i] > 0 {
@@ -815,7 +825,10 @@ func (a NDArray) SliceRoot(start int, length int) NDArray {
 
 func (a NDArray) Normalize(axis int) NDArray {
 	arr := Zeros(a.shape...)
+	return a.NormalizeInto(axis, arr)
+}
 
+func (a NDArray) NormalizeInto(axis int, arr NDArray) NDArray {
 	aggrShape := make([]int, len(a.shape))
 	for i := range aggrShape {
 		if i == axis {
@@ -856,6 +869,10 @@ func (a NDArray) Normalize(axis int) NDArray {
 
 func (a NDArray) InverseNormalize(g NDArray, axis int) NDArray {
 	arr := Zeros(a.shape...)
+	return a.InverseNormalizeInto(g, axis, arr)
+}
+
+func (a NDArray) InverseNormalizeInto(g NDArray, axis int, arr NDArray) NDArray {
 	aggrShape := make([]int, len(a.shape))
 	for i := range aggrShape {
 		if i == axis {
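The pattern in each of these changes is the same: the old allocating method now delegates to a new *Into variant that takes the destination NDArray as an explicit argument, so callers that run the same computation repeatedly can allocate the output once and reuse it. A minimal sketch of that calling pattern follows; the shapes and loop bounds are invented for illustration, while Zeros, SliceInto, and ReLUInto are the calc methods shown above.

package main

import "github.com/tsholmes/go-dl/calc"

func main() {
	x := calc.Zeros(1024, 128) // stand-in input; contents omitted

	// Output buffers allocated once, outside the hot loop.
	batch := calc.Zeros(32, 128)
	activ := calc.Zeros(32, 128)

	for step := 0; step < 100; step++ {
		lo := (step * 32) % 1024
		// Before this commit, x.Slice(...).ReLU() allocated two fresh arrays per step.
		b := x.SliceInto(0, lo, lo+32, batch) // writes the slice into batch
		b.ReLUInto(activ)                     // writes the activation into activ
	}
}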

model/optimizer.go (+3 -3)

@@ -13,7 +13,7 @@ type SGDOptimizer struct {
 func (o *SGDOptimizer) UpdateWeights(weights []calc.NDArray, grads []calc.NDArray) {
 	for i := range weights {
 		w, g := weights[i], grads[i]
-		weights[i] = w.Add(g.MulConstant(-o.LR))
+		weights[i] = w.AddInto(g.MulConstant(-o.LR), w)
 	}
 }
 
@@ -37,9 +37,9 @@ func (o *SGDMomentumOptimizer) UpdateWeights(weights []calc.NDArray, grads []cal
 
 		o.moments[i] = o.moments[i].MulConstant(o.Momentum).Add(g.MulConstant(-o.LR))
 		if o.Nesterov {
-			weights[i] = w.Add(o.moments[i].MulConstant(o.Momentum).Add(g.MulConstant(-o.LR)))
+			weights[i] = w.AddInto(o.moments[i].MulConstant(o.Momentum).Add(g.MulConstant(-o.LR)), w)
 		} else {
-			weights[i] = w.Add(o.moments[i])
+			weights[i] = w.AddInto(o.moments[i], w)
 		}
 	}
}
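The AddInto method the optimizers now call is not shown in this diff. Assuming it mirrors the ReLUInto/SliceInto split above (an assumption, not code from this commit), passing w itself as the destination makes the weight update write in place rather than allocating a fresh weights array each step. A hedged sketch of what that pair likely looks like:

// Sketch only: Add/AddInto as they would look if they follow the same
// pattern as the other Into methods in this commit. The real implementation
// may differ, in particular around broadcasting.
func (a NDArray) Add(b NDArray) NDArray {
	arr := Zeros(BroadcastShape(a.shape, b.shape)...)
	return a.AddInto(b, arr)
}

func (a NDArray) AddInto(b NDArray, arr NDArray) NDArray {
	// Assumes a, b, and arr share one shape; equal-shape addition is all the
	// optimizer needs, since weights and gradients match element for element.
	for i := range arr.data {
		arr.data[i] = a.data[i] + b.data[i]
	}
	return arr
}

Because the addition is purely elementwise, reading a.data[i] and writing arr.data[i] in the same iteration stays correct even when arr shares its backing data with a, which is what makes w.AddInto(g.MulConstant(-o.LR), w) a valid in-place update.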

tensor/ops_cmp.go (+6 -4)

@@ -133,7 +133,7 @@ func (g *gradientVisitor) VisitEqualMask(t *EqualMaskTensor) {
 
 func ReLU(t Tensor) Tensor {
 	return &ReLUTensor{
-		baseTensor: base(t.Shape(), 0, t),
+		baseTensor: base(t.Shape(), 1, t),
 		t: t,
 	}
 }
@@ -147,7 +147,8 @@ func (t *ReLUTensor) Visit(v TensorVisitor) { v.VisitReLU(t) }
 
 func (e *evaluationVisitor) VisitReLU(t *ReLUTensor) {
 	v := e.value(t.t)
-	e.values[t.ID()] = v.ReLU()
+	o := t.values[0]
+	e.values[t.ID()] = v.ReLUInto(o)
 }
 
 func (g *gradientVisitor) VisitReLU(t *ReLUTensor) {
@@ -159,7 +160,7 @@ func (g *gradientVisitor) VisitReLU(t *ReLUTensor) {
 // Zeroes out all values in t where the corresponding value in m is negative
 func ReLUMask(t Tensor, m Tensor) Tensor {
 	return &ReLUMaskTensor{
-		baseTensor: base(t.Shape(), 0, t, m),
+		baseTensor: base(t.Shape(), 1, t, m),
 		t: t,
 		m: m,
 	}
@@ -176,7 +177,8 @@ func (t *ReLUMaskTensor) Visit(v TensorVisitor) { v.VisitReLUMask(t) }
 func (e *evaluationVisitor) VisitReLUMask(t *ReLUMaskTensor) {
 	v := e.value(t.t)
 	mv := e.value(t.m)
-	e.values[t.ID()] = v.ReLUMask(mv)
+	o := t.values[0]
+	e.values[t.ID()] = v.ReLUMaskInto(mv, o)
 }
 
 func (g *gradientVisitor) VisitReLUMask(t *ReLUMaskTensor) {
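In the tensor package, the second argument to base changes from 0 to 1 for every node that now evaluates into a reused buffer. From the way VisitReLU and VisitReLUMask read t.values[0], that argument looks like the number of output NDArrays the node preallocates when the graph is built; the same change repeats for Normalize/InverseNormalize in tensor/ops_math.go and Slice/Unslice in tensor/ops_shape.go below. The snippet here is guesswork meant only to illustrate that idea: node and newNode are hypothetical stand-ins, not the package's real base/baseTensor types.

package main

import "github.com/tsholmes/go-dl/calc"

// Hypothetical stripped-down node that preallocates numValues output buffers,
// roughly what base(shape, 1, t) appears to arrange for the real tensors.
type node struct {
	shape  []int
	values []calc.NDArray
}

func newNode(shape []int, numValues int) node {
	values := make([]calc.NDArray, numValues)
	for i := range values {
		values[i] = calc.Zeros(shape...) // allocated once at graph build time
	}
	return node{shape: shape, values: values}
}

func main() {
	n := newNode([]int{32, 128}, 1)
	in := calc.Zeros(32, 128)
	for i := 0; i < 10; i++ {
		// Every evaluation writes into the same preallocated buffer,
		// mirroring e.values[t.ID()] = v.ReLUInto(t.values[0]) above.
		_ = in.ReLUInto(n.values[0])
	}
}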

tensor/ops_math.go (+6 -4)

@@ -241,7 +241,7 @@ func (g *gradientVisitor) VisitExp(t *ExpTensor) {
 
 func Normalize(t Tensor, axis int) Tensor {
 	return &NormalizeTensor{
-		baseTensor: base(t.Shape(), 0, t),
+		baseTensor: base(t.Shape(), 1, t),
 		t: t,
 		axis: axis,
 	}
@@ -257,7 +257,8 @@ func (t *NormalizeTensor) Visit(v TensorVisitor) { v.VisitNormalize(t) }
 
 func (e *evaluationVisitor) VisitNormalize(t *NormalizeTensor) {
 	v := e.value(t.t)
-	e.values[t.ID()] = v.Normalize(t.axis)
+	o := t.values[0]
+	e.values[t.ID()] = v.NormalizeInto(t.axis, o)
 }
 
 func (g *gradientVisitor) VisitNormalize(t *NormalizeTensor) {
@@ -268,7 +269,7 @@ func (g *gradientVisitor) VisitNormalize(t *NormalizeTensor) {
 
 func InverseNormalize(t Tensor, g Tensor, axis int) Tensor {
 	return &InverseNormalizeTensor{
-		baseTensor: base(t.Shape(), 0, t, g),
+		baseTensor: base(t.Shape(), 1, t, g),
 		t: t,
 		g: g,
 		axis: axis,
@@ -286,7 +287,8 @@ func (t *InverseNormalizeTensor) Visit(v TensorVisitor) { v.VisitInverseNormaliz
 
 func (e *evaluationVisitor) VisitInverseNormalize(t *InverseNormalizeTensor) {
 	v, g := e.value(t.t), e.value(t.g)
-	e.values[t.ID()] = v.InverseNormalize(g, t.axis)
+	o := t.values[0]
+	e.values[t.ID()] = v.InverseNormalizeInto(g, t.axis, o)
 }
 
 func (g *gradientVisitor) VisitInverseNormalize(t *InverseNormalizeTensor) {

tensor/ops_shape.go (+7 -10)

@@ -1,9 +1,5 @@
 package tensor
 
-import (
-	"github.com/tsholmes/go-dl/calc"
-)
-
 func Concat(axis int, as ...Tensor) Tensor {
 	return &ConcatTensor{
 		baseTensor: base(concat(axis, as...), 0, as...),
@@ -42,7 +38,7 @@ func (g *gradientVisitor) VisitConcat(t *ConcatTensor) {
 
 func Slice(t Tensor, axis int, start int, end int) Tensor {
 	return &SliceTensor{
-		baseTensor: base(resize(t, axis, end-start), 0, t),
+		baseTensor: base(resize(t, axis, end-start), 1, t),
 		t: t,
 		axis: axis,
 		start: start,
@@ -62,7 +58,8 @@ func (t *SliceTensor) Visit(v TensorVisitor) { v.VisitSlice(t) }
 
 func (e *evaluationVisitor) VisitSlice(t *SliceTensor) {
 	v := e.value(t.t)
-	e.values[t.ID()] = v.Slice(t.axis, t.start, t.end)
+	o := t.values[0]
+	e.values[t.ID()] = v.SliceInto(t.axis, t.start, t.end, o)
 }
 
 func (g *gradientVisitor) VisitSlice(t *SliceTensor) {
@@ -73,7 +70,7 @@ func (g *gradientVisitor) VisitSlice(t *SliceTensor) {
 
 func Unslice(t Tensor, axis int, size int, offset int) Tensor {
 	return &UnsliceTensor{
-		baseTensor: base(resize(t, axis, size), 0, t),
+		baseTensor: base(resize(t, axis, size), 1, t),
 		t: t,
 		axis: axis,
 		size: size,
@@ -93,9 +90,9 @@ func (t *UnsliceTensor) Visit(v TensorVisitor) { v.VisitUnslice(t) }
 
 func (e *evaluationVisitor) VisitUnslice(t *UnsliceTensor) {
 	v := e.value(t.t)
-	v2 := calc.Zeros(t.Shape()...)
-	v2.SetSlice(v, t.axis, t.offset)
-	e.values[t.ID()] = v2
+	o := t.values[0]
+	o.SetSlice(v, t.axis, t.offset)
+	e.values[t.ID()] = o
 }
 
 func (g *gradientVisitor) VisitUnslice(t *UnsliceTensor) {
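VisitUnslice was evidently the only remaining use of calc.Zeros in this file, which is why the calc import goes away: instead of zeroing a fresh array and copying the slice in, it writes through SetSlice into the buffer that was zero-allocated once at graph construction. A small illustrative sketch of that reuse at the calc level follows; the shapes, offset, and loop are invented, while Zeros and SetSlice are the calls from the removed code.

package main

import "github.com/tsholmes/go-dl/calc"

func main() {
	// Illustrative only: one destination buffer reused across SetSlice calls,
	// the way VisitUnslice now reuses t.values[0].
	out := calc.Zeros(1024, 128) // zeroed once, up front

	for step := 0; step < 100; step++ {
		part := calc.Zeros(32, 128) // stand-in for the evaluated input value
		// Write part into out along axis 0 at offset 64. Outside the written
		// region out keeps its existing contents, which stay zero as long as
		// the axis and offset do not change between evaluations.
		out.SetSlice(part, 0, 64)
	}
}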
