-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathindex.js
37 lines (31 loc) · 1.24 KB
/
index.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import { Model } from './model.js'
import { MnistDataSource } from './mnistDataSource.js'
import { ImageTransformer } from './ImageTransformer.js'
import { RandomDataSource } from './randomDataSource.js'
import { ArbitraryImageDataSource } from './arbitraryImageDataSource.js'
// Run the pipeline. main() is async, so its rejection must be handled here:
// log the failure and flag it with a non-zero exit code rather than leaving the
// promise floating. NOTE(review): assumes a Node.js runtime (uses `process`).
main().catch((err) => {
  console.error(err)
  process.exitCode = 1
})
/**
 * Build or load the autoencoder, run it on the data source's test set, and
 * hand both the original and the reconstructed images to the ImageTransformer
 * to be saved to disk.
 *
 * @param {object} [options]
 * @param {object} [options.dataSource] - Object exposing async
 *   `getTrainingData()` and `getTestData()`. Defaults to a new
 *   `MnistDataSource`; alternatives such as `new RandomDataSource()` or
 *   `new ArbitraryImageDataSource()` can be injected here.
 * @param {number} [options.epochs=200] - Second argument passed to
 *   `model.train` when no pretrained model exists (presumably the number of
 *   training epochs — confirm against Model.train).
 * @returns {Promise<void>} Resolves after both image sets have been handed to
 *   the transformer; rejects if loading, training, or saving fails.
 */
async function main({ dataSource, epochs = 200 } = {}) {
  // Construct in the same order as before: model, data source, transformer
  const model = new Model()
  // Fall back to MNIST when no data source is injected
  const source = dataSource !== undefined ? dataSource : new MnistDataSource()
  const transformer = new ImageTransformer()

  // Reuse a pretrained model if one exists; otherwise create the layers and train
  if (model.pretrainedModelExists()) {
    await model.load()
  } else {
    model.configure()
    await model.train(await source.getTrainingData(), epochs)
  }

  // Run the data source's test data through the autoencoder
  const testData = await source.getTestData()
  const autoEncodedImages = model.autoencode(testData)

  // Save the originals (prefixed 'org') and the reconstructions to disk.
  // `await` is a no-op if toImages is synchronous; if it returns a promise,
  // this keeps main() pending until saving finishes and surfaces its errors.
  await transformer.toImages(testData.arraySync(), 'org')
  await transformer.toImages(autoEncodedImages.arraySync())
}