-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathprepareData.ts
52 lines (41 loc) · 1.46 KB
/
prepareData.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import * as tf from '@tensorflow/tfjs';
import {
IClasses,
} from './types';
/**
 * One-hot encodes a single class index as a tensor with one row and
 * `classLength` columns.
 *
 * @param labelIndex - class index to encode; must be an integer in `[0, classLength)`.
 * @param classLength - number of classes, used as the one-hot depth.
 * @returns the one-hot row produced by `tf.oneHot`.
 * @throws Error if `labelIndex` is not an integer inside `[0, classLength)`.
 *   Such values would otherwise be handed to `tf.oneHot`, which does not reject
 *   them, and would become a wrong or empty encoding instead of an error.
 */
const oneHot = (labelIndex: number, classLength: number) => {
  if (!Number.isInteger(labelIndex) || labelIndex < 0 || labelIndex >= classLength) {
    throw new Error(`Label index ${labelIndex} is out of range for ${classLength} classes`);
  }
  // Create the index tensor as int32 up front instead of building a float tensor
  // and casting it; tf.tidy disposes that intermediate index tensor.
  return tf.tidy(() => tf.oneHot(tf.tensor1d([labelIndex], 'int32'), classLength));
};
// const turnTensorArrayIntoTensor = (tensors: tf.Tensor[]) => tensors.reduce((data?: tf.Tensor, tensor: tf.Tensor) => tf.tidy(() => {
// if (data === undefined) {
// return tf.keep(tensor);
// }
// const newData = tf.keep(data.concat(tensor, 0));
// data.dispose();
// return newData;
// }), undefined);
/**
 * Concatenates a list of tensors along axis 0 into a single tensor.
 *
 * All rows are joined with one `tf.concat` call. This avoids repeatedly
 * re-copying an ever-growing accumulator, which would be quadratic in the
 * number of tensors. The result is marked with `tf.keep`, so an enclosing
 * `tf.tidy` in the caller will not dispose it.
 *
 * The input tensors are not disposed; the caller keeps ownership of them.
 *
 * @param tensors - tensors to stack; they must agree on every dimension
 *   except the first (a `tf.concat` requirement).
 * @returns a tensor containing every input, in order, along axis 0.
 * @throws Error if `tensors` is empty.
 */
export const addData = (tensors: tf.Tensor[]): tf.Tensor => {
  if (tensors.length === 0) {
    throw new Error('addData requires at least one tensor');
  }
  return tf.keep(tf.concat(tensors, 0));
};
/**
 * Builds a one-hot label tensor with one row per entry of `labels`, in order.
 *
 * Each label is looked up in `classes` to get its class index. That index is
 * one-hot encoded over `Object.keys(classes).length` columns, and all rows are
 * joined with a single `tf.concat`. The result is marked with `tf.keep`, so an
 * enclosing `tf.tidy` in the caller will not dispose it.
 *
 * @param labels - label names; each must be an own key of `classes`.
 * @param classes - map from label name to class index (presumably the indices
 *   are 0..n-1 — TODO confirm against the IClasses definition).
 * @returns the stacked one-hot tensor, or `undefined` when `labels` is empty.
 * @throws Error if `classes` has fewer than 2 entries, or if a label is not a
 *   key of `classes`.
 */
export const addLabels = (labels: string[], classes: IClasses): tf.Tensor2D | undefined => {
  const classLength = Object.keys(classes).length;
  if (classLength <= 1) {
    throw new Error('You must provide more than 1 class for training');
  }
  if (labels.length === 0) {
    return undefined;
  }

  // Resolve every label first, so a bad label fails before any tensor is created.
  // hasOwnProperty also rejects inherited keys such as 'toString'. Without this
  // check, an unknown label would pass `undefined` on to the one-hot encoder.
  const labelIndices = labels.map((label) => {
    if (!Object.prototype.hasOwnProperty.call(classes, label)) {
      throw new Error(`Unknown label "${label}"; it is not a key of the provided classes`);
    }
    return classes[label];
  });

  // tf.tidy disposes the per-label rows; only the kept concatenation survives.
  return tf.tidy(() => tf.keep(tf.concat(
    labelIndices.map((labelIndex) => oneHot(labelIndex, classLength)),
    0,
  )));
};