# buildDatabase.py
from databaseUtils import Database
from databaseUtils import getIntLabels
import time
from datetime import timedelta
import math
import random
import numpy as np
from numpy.random import seed
# Script: open the dataset described by a JSON config, map every class name
# to an integer label, and print a few shapes/counts for inspection.
dataset = Database("vggfaces.json")
# Alternative dataset configuration:
# dataset = Database("flowers17.json")

# NOTE(review): presumably prepares the on-disk HDF5 groups before the
# database is opened -- confirm against databaseUtils.
dataset.createHDF5DatabaseGroups()
dataset.openDatabase()

# Everything between open and close runs under try/finally so the database
# is closed even if one of the reads below raises.
try:
    trainset = dataset.getTrainset()
    devset = dataset.getDevset()
    testset = dataset.getTestset()

    classNames = dataset.getClassNames()
    print(classNames)

    # Assign each class name a unique integer (its position in classNames)
    # and keep the reverse integer -> name mapping as well.
    clsLabelDict = {}
    labelClsDict = {}
    for lbl, clsname in enumerate(classNames):
        clsLabelDict[clsname] = lbl
        labelClsDict[lbl] = clsname

    # Take the first column of the stored string labels and convert it to
    # integer labels through the name -> integer mapping.
    trainStrLabels = trainset["strLabels"][:, 0]
    trainLabels = getIntLabels(clsLabelDict, trainStrLabels)
    print("trainLabels.shape", trainLabels.shape)

    # The first row of the "index" table is read as the (start, end) bounds
    # of the first batch inside the "images" array.
    numBatch = len(trainset["index"])
    start = trainset["index"][0, 0]
    end = trainset["index"][0, 1]

    imgsBatch1 = trainset["images"][start:end, :]
    print(imgsBatch1.shape)
    print(devset["images"].shape)
    print("number of classes", len(clsLabelDict))
    print("numbBatch = ", numBatch , " approx num of imgs in each batch = ", end-start)
finally:
    dataset.closeDatabase()