-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathindex.test.ts
78 lines (71 loc) · 2.33 KB
/
index.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import * as tf from '@tensorflow/tfjs';
import getDefaultDownloadHandler from './getDefaultDownloadHandler';
jest.mock('./getDefaultDownloadHandler');
// Replace TensorFlow.js with a minimal hand-written stub. Jest hoists
// jest.mock() calls above the imports, so this factory is registered before
// `./index` is loaded; the factory must not close over module-scope variables.
jest.mock('@tensorflow/tfjs', () => ({
  train: {
    // Stub optimizer factory; returns undefined.
    adam: () => {},
  },
  // `model` is a plain object here, not tfjs's `tf.model()` factory: tests
  // assign it as the classifier's trained model. Its `save` echoes back the
  // handler or URL it receives, so a test can assert which destination the
  // classifier forwarded to it.
  model: {
    save: (handlerOrURL: unknown) => handlerOrURL,
  },
  // Minimal stand-in for a loaded (pretrained) model.
  // NOTE(review): presumably `./index` reads only getLayer(...).output and
  // `inputs` from the loaded model — confirm against that module.
  loadModel: () => ({
    getLayer: () => ({
      output: null,
    }),
    inputs: [],
  }),
}));
import MLClassifier from './index';
// Unit tests for MLClassifier, run against the mocked '@tensorflow/tfjs' and
// './getDefaultDownloadHandler' modules configured at the top of this file.
describe('ml-classifier', () => {
  // Smoke test: `./index` loads under the mocks and default-exports a
  // constructor (the disabled suites below instantiate it with `new`).
  test('exports the MLClassifier constructor', () => {
    expect(typeof MLClassifier).toBe('function');
  });

  // NOTE(review): the suites below are disabled and the reason is not recorded
  // here — presumably constructing MLClassifier does not yet work under the
  // mocks. Re-enable them once it does.
  // describe('constructor', () => {
  //   test('that it persists params', async () => {
  //     const epochs = 123;
  //     const mlClassifier = new MLClassifier({
  //       epochs,
  //     });
  //     expect(mlClassifier.getParams().epochs).toEqual(epochs);
  //   });
  //   // test('that it calls init on construct', async () => {
  //   //   MLClassifier.prototype.init = jest.fn(() => {});
  //   //   const mlClassifier = new MLClassifier({ });
  //   //   expect(mlClassifier.init).toHaveBeenCalled();
  //   // });
  // });
  // describe('save', () => {
  //   let mlClassifier;
  //   beforeEach(() => {
  //     mlClassifier = new MLClassifier();
  //     mlClassifier.loaded = jest.fn(() => {});
  //   });
  //   test('it waits for pretrained model as the first step', async () => {
  //     mlClassifier.model = tf.model;
  //     await mlClassifier.save();
  //     expect(mlClassifier.loaded).toHaveBeenCalled();
  //   });
  //   test('it throws if no model is set', async () => {
  //     // `rejects` fails the test if save() resolves; asserting only inside
  //     // a .catch() handler would let a resolved promise pass silently.
  //     await expect(mlClassifier.save()).rejects.toThrow(
  //       new Error('You must call train prior to calling save'),
  //     );
  //   });
  //   test('calls save with a handler if specified', async () => {
  //     const url = 'foobar';
  //     mlClassifier.model = tf.model;
  //     const result = await mlClassifier.save(url);
  //     expect(result).toEqual(url);
  //   });
  //   test('calls save with a default handler if none is specified', async () => {
  //     const def = 'def';
  //     // The import is typed as the real function; cast to reach the mock API.
  //     (getDefaultDownloadHandler as jest.Mock).mockImplementationOnce(() => def);
  //     mlClassifier.model = tf.model;
  //     const result = await mlClassifier.save();
  //     expect(result).toEqual(def);
  //   });
  // });
});