-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathloadPretrainedModel.test.ts
35 lines (32 loc) · 1.04 KB
/
loadPretrainedModel.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import * as tf from '@tensorflow/tfjs';
import loadPretrainedModel, {
PRETRAINED_MODELS_KEYS,
PRETRAINED_MODELS,
} from './loadPretrainedModel';
// Replace @tensorflow/tfjs with a minimal hand-written stub so the tests never
// exercise the real library. Jest hoists `jest.mock` calls above the imports,
// so this factory must only reference its own locals and the `jest` global.
jest.mock('@tensorflow/tfjs', () => ({
  // Stub for tf.model(config): returns the config it was given, plus a
  // `save` that simply returns whatever handler/URL it receives.
  model: (params: Record<string, unknown>) => ({
    ...params,
    save: (handlerOrURL: unknown) => handlerOrURL,
  }),
  // Stub for tf.loadModel(url): a jest.fn so tests can assert on the URL it
  // was called with (Jest records call arguments regardless of the
  // implementation's signature). Returns a fake model exposing a
  // `getLayer()` whose `output` is null and an empty `inputs` list.
  loadModel: jest.fn(() => ({
    getLayer: () => ({
      output: null,
    }),
    inputs: [],
  })),
}));
/**
 * Tests for `loadPretrainedModel`. `@tensorflow/tfjs` is replaced by the
 * jest.mock factory at the top of this file, so `tf.loadModel` here is a
 * jest mock function rather than the real model loader.
 */
describe('loadPretrainedModel', () => {
  // The mocked `tf.loadModel` is a single module-level jest.fn shared by every
  // test. Clear its recorded calls so one test cannot satisfy another test's
  // call assertions.
  beforeEach(() => {
    jest.clearAllMocks();
  });

  test('it throws an error if an invalid key is provided', async () => {
    // `.rejects` fails the test if the promise resolves. An assertion placed
    // only inside `.catch(...)` would never run in that case, so the test
    // would pass vacuously.
    await expect(loadPretrainedModel('foo')).rejects.toHaveProperty(
      'message',
      'You have supplied an invalid key for a pretrained model',
    );
  });

  test('loads a pretrained model specified in the config with tf.loadModel', async () => {
    await loadPretrainedModel(PRETRAINED_MODELS_KEYS.MOBILENET);
    // `tf.loadModel` is already a jest.fn (see the module mock). Wrapping it
    // in jest.spyOn would just hand back the same mock, so assert on it
    // directly.
    expect(tf.loadModel).toHaveBeenCalledWith(
      PRETRAINED_MODELS[PRETRAINED_MODELS_KEYS.MOBILENET].url,
    );
  });
});