-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_data.py
138 lines (115 loc) · 5.34 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import os
import numpy as np
import time
import h5py
# Command-line interface. Options are declared as (flags, keyword-arguments)
# pairs and registered in order, so `--help` lists them in the same sequence.
# Without --typemap / --no-typemap, args.typemap stays None (falsy).
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    (('--input', '-i'),
     dict(type=str, default='data/band.h5', help='Location of dft output file')),
    (('--rots', '-r'),
     dict(type=str, default='data/rotations_210611.pickle', help='Location of molecule rotations')),
    (('--output', '-o'),
     dict(type=str, default='data/data.hdf5', help='Save location')),
    (('--typemap', '-t'),
     dict(action=argparse.BooleanOptionalAction,
          help='Create typemap instead of disks, spheres, heights')),
    (('--ncpu',),
     dict(type=int, default=8, help="Number of workers")),
)
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)
args = parser.parse_args()
# Limit OpenMP threads when more than one worker is requested. This is kept
# ahead of the ppstm/ppafm imports, presumably so their OpenMP runtime picks
# the setting up when it initializes.
if args.ncpu > 1:
    os.environ['OMP_NUM_THREADS'] = str(args.ncpu)
# With --ncpu <= 1 the variable may not be set at all; indexing os.environ
# would raise KeyError, so report it as 'unset' instead.
print('OMP_NUM_THREADS:', os.environ.get('OMP_NUM_THREADS', 'unset'))
from ppstm.pyPPSTM import STMulator, STMgenerator
from ppafm.ml import AuxMap as aux
def pad_xyzs(xyzs, max_len):
    """Zero-pad per-molecule 2D arrays to a common row count and stack them.

    Args:
        xyzs: Sequence of 2D arrays, one per molecule. Rows are padded;
            columns are left unchanged.
        max_len: Number of rows every array is padded to.

    Returns:
        Array of shape (len(xyzs), max_len, n_cols), with zero rows appended
        after each molecule's own rows.

    Raises:
        ValueError: If any array has more than ``max_len`` rows (reported
            explicitly instead of np.pad's opaque negative-width error), or if
            ``xyzs`` is empty (raised by np.stack).
    """
    padded = []
    for index, xyz in enumerate(xyzs):
        n_rows = len(xyz)
        if n_rows > max_len:
            raise ValueError(
                f'Molecule {index} has {n_rows} atoms, more than '
                f'max_len={max_len}; increase max_len.'
            )
        # Pad only the row axis, at the end; the column axis gets no padding.
        padded.append(np.pad(xyz, ((0, max_len - n_rows), (0, 0))))
    return np.stack(padded, axis=0)
# NOTE(review): np.load(..., allow_pickle=True) unpickles the --rots file, and
# unpickling can execute arbitrary code. Only use rotation files from trusted
# sources.
rotations = np.load(args.rots, allow_pickle=True)

# Number of samples per split. rotations[key] maps a molecule id to a
# sequence (presumably that molecule's rotations); every element of that
# sequence is counted as one sample.
set_sizes = dict()
# Nothing is read from the input file here. Opening it appears to serve only
# as an early check that it exists and is readable HDF5.
with h5py.File(args.input, mode='r'):
    for key in rotations:
        split = rotations[key]
        set_sizes[key] = sum(len(split[mol_id]) for mol_id in split)
# Scan region as a (start corner, end corner) pair; the first two components
# of each corner are x and y, the third is z.
scan_window = ((2.0, 2.0, 10.0), (18.0, 18.0, 11.0))
# Number of grid points along x, y and z.
scan_dim = (128, 128, 20)
# Tip orbital, samples per generated batch, and the generator's dist_above.
tip_orb, batch_size, above = 'CO', 30, 3.3

# Simulator shared by every split; AFM output is requested alongside STM.
stmulator = STMulator.STMulator(
    scan_window=scan_window,
    scan_dim=scan_dim,
    tip_orb=tip_orb,
    return_afm=True,
    timings=False,
)
def _xy_grid_kwargs():
    """Return fresh keyword arguments describing the 2D (x, y) scan grid.

    The grid size and window are the first two components of the simulator's
    3D scan settings. A new dict and list are built on every call, so each
    auxiliary map receives its own objects.
    """
    return dict(
        scan_dim=scan_dim[:2],
        scan_window=[stmulator.scan_window[0][:2], stmulator.scan_window[1][:2]],
    )


# Target ("auxiliary") maps produced alongside the simulated images: a single
# element-typed sphere map with --typemap, otherwise disks, vdW spheres and a
# height map.
if not args.typemap:
    aux_maps = [
        aux.AtomicDisks(**_xy_grid_kwargs()),
        aux.vdwSpheres(**_xy_grid_kwargs()),
        aux.HeightMap(scanner=stmulator.afmulator.scanner),
    ]
else:
    aux_maps = [aux.MultiMapSpheresElements(**_xy_grid_kwargs())]
# Total number of batches across all splits, used only for the ETA estimate.
# A split's last batch can be partial (n_batch below is len(xyz), not
# batch_size), so count it with ceiling division; plain division undercounts.
# Assumes the generator yields ceil(n / batch_size) batches per split.
total_len = sum(
    (set_sizes[split] + batch_size - 1) // batch_size
    for split in ('train', 'val', 'test')
)

with h5py.File(args.output, 'w') as f:
    start_time = time.time()
    counter = 1  # batches completed across all splits once the current one is done
    for mode in ['train', 'val', 'test']:
        # Batch generator for this split.
        generator = STMgenerator.STMgenerator(stmulator=stmulator,
                                              auxmaps=aux_maps,
                                              h5_name=args.input,
                                              mode=mode,
                                              rotations_pkl=args.rots,
                                              batch_size=batch_size,
                                              dist_above=above,
                                              timings=False)

        # Dataset shapes: one row per sample in this split.
        n_mol = set_sizes[mode]
        max_mol_len = 54  # xyz rows are zero-padded to this many atoms
        X_shape = (
            n_mol,                  # Number of samples
            2,                      # STM+AFM
            stmulator.scan_dim[0],  # x size
            stmulator.scan_dim[1],  # y size
            stmulator.scan_dim[2] - stmulator.afmulator.scanner.nDimConvOut + 1,  # z size
        )
        if args.typemap:
            # Typemap targets carry an extra trailing axis of size 3.
            Y_shape = (n_mol, len(aux_maps)) + X_shape[2:4] + (3,)
        else:
            Y_shape = (n_mol, len(aux_maps)) + X_shape[2:4]
        xyz_shape = (n_mol, max_mol_len, 5)

        # One HDF5 group per split; every dataset is chunked one sample at a time.
        g = f.create_group(mode)
        X_h5 = g.create_dataset('X', shape=X_shape, chunks=(1,) + X_shape[1:], dtype='f')
        if len(aux_maps) > 0:
            Y_h5 = g.create_dataset('Y', shape=Y_shape, chunks=(1,) + Y_shape[1:], dtype='f')
        xyz_h5 = g.create_dataset('xyz', shape=xyz_shape, chunks=(1,) + xyz_shape[1:], dtype='f')

        # Generate data, writing each batch at the next free row.
        ind = 0
        for i, ((X_stm, X_afm), Y, xyz) in enumerate(generator):
            n_batch = len(xyz)  # can be smaller than batch_size for the last batch
            # Join STM and AFM along axis 1. Assumes each has size 1 on that axis,
            # giving (n_batch, 2, x, y, z) to match X_shape.
            X = np.concatenate([X_stm, X_afm], axis=1, dtype='f')
            X_h5[ind:ind + n_batch] = X
            if len(aux_maps) > 0:
                Y = np.stack(Y, axis=1)  # (n_batch, len(aux_maps), ...)
                # Swap the two spatial axes of the 2D maps. Writing the
                # transpose back into the same slice requires a square grid
                # (scan_dim[0] == scan_dim[1]).
                if args.typemap:
                    Y[:, 0] = Y[:, 0].transpose(0, 2, 1, 3)  # multimap
                else:
                    Y[:, 0] = Y[:, 0].transpose(0, 2, 1)
                    Y[:, 1] = Y[:, 1].transpose(0, 2, 1)
                    # Y[:, 2] (HeightMap) is already in the correct orientation.
                Y_h5[ind:ind + n_batch] = Y
            xyz_h5[ind:ind + n_batch] = pad_xyzs(xyz, max_mol_len)
            ind += n_batch

            # Print progress info every 10 batches.
            if i % 10 == 0:
                eta = (time.time() - start_time) / counter * (total_len - counter)
                print(f'Generated {mode} batch {i+1}/{len(generator)} - ETA: {eta:.1f}s')
            counter += 1

        # If the generator produced fewer samples than the rotations file
        # implies, the remaining rows keep the dataset's fill value. Report it
        # instead of leaving a silently incomplete split.
        if ind != n_mol:
            print(f'Warning: wrote {ind} of {n_mol} {mode} samples; remaining rows were left unfilled')

print(f'Total time taken: {time.time() - start_time:.1f}s')