-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy patharbitraryImageDataSource.js
44 lines (38 loc) · 1.3 KB
/
arbitraryImageDataSource.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import fs from 'fs'
import sharp from 'sharp'
import tf from '@tensorflow/tfjs-node'
/**
 * Loads images from the `images` directory (resolved against the process's
 * current working directory) and turns them into normalized tensors.
 *
 * Each image is cover-cropped to 28x28, stripped of any alpha channel,
 * gamma-corrected, converted to greyscale, and its byte values are divided
 * by 255.
 * NOTE(review): presumably the model consumes 28*28 single-channel values per
 * image — confirm against the model's input shape.
 */
export class ArbitraryImageDataSource {
  /**
   * Scans `images/` for .jpg/.jpeg/.png files (case-insensitive) and randomly
   * splits them into a test set and a disjoint training set.
   *
   * @param {number} [countTraining=1000] maximum number of training images
   * @param {number} [countTest=10] maximum number of test images
   */
  constructor(countTraining = 1000, countTest = 10) {
    const files = fs.readdirSync('images')
      .filter(f => /\.(jpe?g|png)$/i.test(f))
      .map(f => `images/${f}`)

    // Shuffle once and carve out non-overlapping slices so no test image is
    // also trained on. The test set is taken first so it keeps its requested
    // size; the training set gets what remains, which is fewer than
    // countTraining when the directory holds fewer than
    // countTest + countTraining images.
    const shuffled = files.shuffle()
    this.testFiles = shuffled.slice(0, countTest)
    this.trainingFiles = shuffled.slice(countTest, countTest + countTraining)
  }

  /**
   * Loads the training images.
   * @returns {Promise<tf.Tensor>} one row per training image, values divided by 255
   */
  async getTrainingData() {
    return this._loadTensor(this.trainingFiles)
  }

  /**
   * Loads the test images.
   * @returns {Promise<tf.Tensor>} one row per test image, values divided by 255
   */
  async getTestData() {
    return this._loadTensor(this.testFiles)
  }

  /**
   * Decodes every file in parallel and stacks the pixel buffers into a single
   * tensor scaled by 1/255.
   * @param {string[]} files image paths to load
   * @returns {Promise<tf.Tensor>} one row per file
   */
  async _loadTensor(files) {
    const data = await Promise.all(files.map(f => this._processImageFile(f)))
    // tf.tidy disposes the unscaled intermediate tensor; tfjs tensors are not
    // released until disposed, so without it every call would leak one.
    return tf.tidy(() => tf.tensor(data).div(255))
  }

  /**
   * Reads one image and returns its raw pixel bytes after a 28x28 cover-crop
   * resize, gamma correction and 8-bit greyscale conversion.
   * @param {string} filename path of the image to read
   * @returns {Promise<Buffer>} raw pixel data
   */
  _processImageFile(filename) {
    return sharp(filename)
      .resize(28, 28, { fit: 'cover' })
      // greyscale() leaves an alpha channel in place, so transparent PNGs would
      // produce more bytes per image than opaque JPEGs and tf.tensor() would
      // reject the mixed batch.
      .removeAlpha()
      .gamma()
      .greyscale()
      .raw()
      .toBuffer()
  }
}
/**
 * `Array.prototype.shuffle()` returns a new array with this array's elements
 * in random order; the receiver is left unchanged.
 *
 * Implemented as a Fisher–Yates shuffle on a copy: every permutation is equally
 * likely and the work is O(n). The method is installed with
 * Object.defineProperty as non-enumerable, because a plain assignment to
 * Array.prototype would make `shuffle` appear in every `for...in` loop over any
 * array. It stays writable and configurable, like an assigned property.
 *
 * Uses Math.random(), which is fine for sampling data but not for anything
 * security-sensitive.
 *
 * @returns {Array} a shuffled shallow copy
 */
Object.defineProperty(Array.prototype, 'shuffle', {
  value: function shuffle() {
    const result = Array.from(this)
    for (let i = result.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1))
      const tmp = result[i]
      result[i] = result[j]
      result[j] = tmp
    }
    return result
  },
  writable: true,
  configurable: true,
  enumerable: false,
})