Skip to content

Commit 3610cfb

Browse files
alignPCA-2D (#1057)
Latest changes to alignPCA-2D, an alignment method applied for 2D classification using PCA and Euclidean distance. --------- Co-authored-by: alberto <alber.mena@gmail.com> Co-authored-by: albertmena <12760268+albertmena@users.noreply.github.com>
1 parent 1a442c5 commit 3610cfb

8 files changed

Lines changed: 1313 additions & 844 deletions

File tree

Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
#!/usr/bin/env python3
2+
"""/***************************************************************************
3+
*
4+
* Authors: Erney Ramirez-Aportela
5+
*
6+
***************************************************************************/
7+
"""
8+
9+
import mrcfile
10+
import argparse
11+
import sys, os
12+
import numpy as np
13+
import torch
14+
from xmippPyModules.classifyPcaFuntion.bnb_gpu import BnBgpu
15+
from xmippPyModules.classifyPcaFuntion.pca_gpu import PCAgpu
16+
from xmippPyModules.classifyPcaFuntion.assessment import evaluation
17+
18+
19+
def read_images(mrcfilename):
    """Load the data block of an MRC file as a float32 array.

    The file is opened in permissive mode and closed before returning.

    Args:
        mrcfilename: path of the MRC/MRCS file to read.

    Returns:
        A new C-contiguous ``numpy.float32`` array that owns its memory and
        holds the file's data.
    """
    with mrcfile.open(mrcfilename, permissive=True) as f:
        # astype() already allocates a fresh array, so chaining .copy() would
        # duplicate the whole stack a second time.  ascontiguousarray() is a
        # no-op for C-ordered data and only copies when the layout requires
        # it, which keeps the C-ordered result the previous .copy() produced.
        emImages = np.ascontiguousarray(f.data.astype(np.float32))
    return emImages
24+
25+
26+
def save_images(data, voxel, outfilename):
    """Write an image (or stack of images) to an MRC file as float32.

    Args:
        data: numpy array with the image data; a single 2-D image is
            promoted to a one-slice stack before writing.
        voxel: pixel size stored in the header for the x and y axes
            (the z spacing is written as 1).
        outfilename: destination path; an existing file is overwritten.
    """
    stack = data.astype('float32')
    # A lone 2-D image gets a leading axis so it is stored as a stack.
    stack = stack[np.newaxis, ...] if stack.ndim == 2 else stack

    pixel_spacing = (voxel, voxel, 1)
    with mrcfile.new(outfilename, overwrite=True) as out_mrc:
        out_mrc.set_data(stack)
        out_mrc.voxel_size = pixel_spacing
        out_mrc.update_header_stats()
36+
37+
38+
def flatGrid(freq_band, nBand, device=None):
    """Collect the Fourier-space coordinates that belong to each band.

    A half-plane frequency grid is built with ``torch.fft.rfftfreq`` along x
    and ``torch.fft.fftfreq`` along y (sample spacing ``0.5 / pi``), giving
    coordinate maps of shape ``(dim, dim // 2 + 1)`` where ``dim`` is the
    first dimension of ``freq_band``.  For every band index ``n`` in
    ``range(nBand)`` the coordinates where ``freq_band == n`` are extracted.

    Args:
        freq_band: 2-D map of band indices; it is compared against each band
            index and used as a boolean mask over the frequency grid, so it
            must be mask-compatible with the ``(dim, dim // 2 + 1)`` grid.
        nBand: number of bands to extract.
        device: torch device on which the grid is created.  Defaults to
            ``cuda:0`` so the function no longer depends on a global that only
            exists when the file is executed as a script.

    Returns:
        A list with ``nBand`` tensors of shape ``(2, n_points)``; row 0 holds
        the x frequencies and row 1 the y frequencies of that band.
    """
    if device is None:
        device = torch.device('cuda:0')

    dim, _ = freq_band.shape

    fx = torch.fft.rfftfreq(dim, d=0.5/np.pi, device=device)
    fy = torch.fft.fftfreq(dim, d=0.5/np.pi, device=device)

    # 'xy' indexing: rows follow fy, columns follow fx.
    grid_x, grid_y = torch.meshgrid(fx, fy, indexing='xy')

    grid_flat = []
    for n in range(nBand):
        mask = (freq_band == n)
        grid_flat.append(torch.stack([grid_x[mask], grid_y[mask]], dim=0))

    return grid_flat
62+
63+
64+
def _centered_pose(tMatrix, dim):
    """Return per-image rotation angles and shifts of centred transforms.

    Each 2x3 affine matrix is promoted to a 3x3 homogeneous matrix and
    conjugated with a translation of ``-dim/2`` (T * M * T^-1).  This is done
    because the rotation is performed from the centre of the image.

    Args:
        tMatrix: tensor of shape ``(N, 2, 3)`` with one affine matrix per image.
        dim: image side length in pixels.

    Returns:
        ``(angles_deg, shifts)``: a numpy array with N angles in degrees and a
        tensor of shape ``(N, 2)`` with the translations.
    """
    n = tMatrix.size(0)
    device = tMatrix.device

    initial_shift = torch.tensor([[1.0, 0.0, -dim/2],
                                  [0.0, 1.0, -dim/2],
                                  [0.0, 0.0, 1.0]], device=device)
    initial_shift = initial_shift.unsqueeze(0).expand(n, -1, -1)

    # Append the homogeneous row [0, 0, 1] to every matrix.
    homogeneous = torch.cat((tMatrix, torch.zeros((n, 1, 3), device=device)), dim=1)
    homogeneous[:, 2, 2] = 1.0
    homogeneous = torch.matmul(initial_shift, homogeneous)
    homogeneous = torch.matmul(homogeneous, torch.inverse(initial_shift))

    angles_rad = torch.atan2(homogeneous[:, 1, 0], homogeneous[:, 0, 0])
    return np.degrees(angles_rad.cpu().numpy()), homogeneous[:, :2, 2]


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="align images")
    parser.add_argument("-i", "--exp", help="input mrc file for experimental images)", required=True)
    parser.add_argument("-s", "--sampling", type=float, help="pixel size of the images", required=True)
    parser.add_argument("-c", "--classes", type=int, help="number of 2D classes", required=True)
    parser.add_argument("-r", "--ref", help="2D classes of external method")
    parser.add_argument("--mask", action="store_true", help="A Gaussian mask is used.")
    parser.add_argument("--sigma", type=float, help="value of sigma for the Gaussian mask. "
                                                    "It is only used if the --mask option is applied.")
    parser.add_argument("-o", "--output", help="Root directory for the output files", required=True)
    parser.add_argument("-stExp", "--starExp", help="star file for images")
    # For training
    parser.add_argument("-t", "--training", type=int, help="number of image for training", required=True)
    parser.add_argument("-hr", "--highres", type=float, help="highest resolution to consider", required=True)
    parser.add_argument("-p", "--perc", type=float, help="PCA percentage (between 0-1)", required=True)

    args = parser.parse_args()

    expFile = args.exp
    sampling = args.sampling
    final_classes = args.classes
    refImages = args.ref
    niter = 18  # iterations used while classes are being created
    mask = args.mask
    sigma = args.sigma
    output = args.output
    expStar = args.starExp
    Ntrain = args.training
    highRes = args.highres
    per_eig_value = args.perc

    # The availability check used to be evaluated and discarded; fail early
    # with an explicit message instead.
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA capable GPU is required to run this program")
    # Module-level device handle used for every GPU allocation below.
    cuda = torch.device('cuda:0')

    # Determining GPU free memory (total minus what this process has allocated)
    gpu = torch.cuda.get_device_properties(0)
    total_memory = gpu.total_memory
    allocated_memory = torch.cuda.memory_allocated(0)
    free_memory = (total_memory - allocated_memory) / (1024 ** 3)  # free memory GB
    print("Free memory %s" % free_memory)

    # Read images lazily through a memory map
    mmap = mrcfile.mmap(expFile, permissive=True)
    nExp = mmap.data.shape[0]
    dim = mmap.data.shape[1]

    if mask and (sigma is None):
        sigma = dim/3

    # Per-particle results, filled batch by batch
    refClas = torch.zeros(nExp)
    dist = torch.zeros(nExp)
    translation_vector = torch.zeros(nExp, 2)
    angles_deg = np.zeros(nExp)

    # PCA function
    nBand = 1
    pca = PCAgpu(nBand)

    # With external references the target resolution is used from the start;
    # otherwise the basis starts at 16.0 and is refined during class creation.
    maxRes = highRes if refImages else 16.0

    freqBn, cvecs, coef = pca.calculatePCAbasis(mmap, Ntrain, nBand, dim, sampling, maxRes,
                                                minRes=530, per_eig=per_eig_value, batchPCA=True)

    grid_flat = flatGrid(freqBn, nBand)

    bnb = BnBgpu(nBand)

    expBatchSize, expBatchSize2, numFirstBatch, initClBatch = bnb.determine_batches(free_memory, dim)
    print("batches: %s, %s, %s, %s" % (expBatchSize, expBatchSize2, numFirstBatch, initClBatch))

    # Initial classes: external references, or k-means averages
    if refImages:
        initStep = -1
        clIm = read_images(refImages)
        cl = torch.from_numpy(clIm).float().to(cuda)
    else:
        initStep = int(min(numFirstBatch, np.ceil(nExp/expBatchSize)))

        # Start from 60% of the requested classes.  At least one cluster is
        # requested: with fewer than two classes the integer division gives 0,
        # which left no rounds and raised ZeroDivisionError below.
        num_clusters_total = max(1, final_classes * 60 // 100)
        max_classes_per_round = 50

        # Split the clusters into rounds of at most max_classes_per_round
        clusters_per_round = []
        remaining = num_clusters_total
        while remaining > 0:
            n = min(remaining, max_classes_per_round)
            clusters_per_round.append(n)
            remaining -= n

        num_rounds = len(clusters_per_round)

        # Each round draws its sample from its own contiguous block of images
        block_size = (nExp + num_rounds - 1) // num_rounds
        all_averages = []

        for r, k_round in enumerate(clusters_per_round):
            start = r * block_size
            end = min(start + block_size, nExp)
            block_len = end - start
            if block_len <= 0:
                break

            batch_size = min(initClBatch, block_len)
            indices = np.random.choice(block_len, size=batch_size, replace=False) + start

            Im_zero = mmap.data[indices].astype(np.float32)
            Texp_zero = torch.as_tensor(Im_zero, device=cuda)
            Texp_zero *= bnb.create_circular_mask(Texp_zero)

            pca_zero = bnb.create_batchExp(Texp_zero, freqBn, coef, cvecs)

            cl_round, _ = bnb.kmeans_pytorch_for_averages(
                Texp_zero, pca_zero[0], cvecs, num_clusters=k_round
            )
            all_averages.append(cl_round)

            # Release the round's buffers before the next allocation
            del Im_zero, Texp_zero, pca_zero, cl_round

        cl = torch.cat(all_averages, dim=0)
        del all_averages

    if refImages:
        num_batches = int(np.ceil(nExp / expBatchSize2))
    else:
        num_batches = min(int(np.ceil(nExp / expBatchSize)),
                          int(numFirstBatch + np.ceil((nExp - (numFirstBatch * expBatchSize))/(expBatchSize2))))

    # Iterations after which the PCA basis is recomputed, mapped to the
    # resolution requested at that point (never finer than highRes).
    res_map = {4: 14, 7: 12, 10: 10, 13: highRes}

    ### Start initial cycles
    num_cycles = 1
    for cycles in range(num_cycles):
        batch_projExp_cpu = []
        endBatch = 0
        num_batches_in_iter = initStep if cycles < num_cycles - 1 else num_batches

        ### Start iterations per batches
        for i in range(num_batches_in_iter):

            mode = False

            if i < initStep:
                initBatch = i * expBatchSize
                endBatch = min((i + 1) * expBatchSize, nExp)
            else:
                initBatch = endBatch
                endBatch = min(endBatch + expBatchSize2, nExp)

            expImages = mmap.data[initBatch:endBatch].astype(np.float32)
            Texp = torch.from_numpy(expImages).float().to(cuda)

            if i < initStep:
                # The first batches are only projected and stored; classes are
                # created once the last of them has been collected.
                batch_projExp_cpu.append(bnb.batchExpToCpu(Texp, freqBn, coef, cvecs))
                if i == initStep - 1:
                    mode = "create_classes"
                    print("\nClassification mode", flush=True)
                    print(f"Processing batch 0 - {endBatch}\n", flush=True)
            else:
                batch_projExp_cpu = bnb.create_batchExp(Texp, freqBn, coef, cvecs)
                mode = "align_classes"
                if i == initStep:
                    print("\nAssignment mode", flush=True)
                # NOTE(review): nesting of this message was reconstructed; it
                # is assumed to be printed for every assignment batch - confirm.
                print(f"Processing batch {initBatch} - {endBatch}", flush=True)
            del Texp

            if not mode:
                del expImages
                continue

            # Initialization Transformation Matrix: one identity per image.
            # Class creation works on everything read so far, assignment only
            # on the current batch.
            if mode == "create_classes":
                subset = endBatch
            else:
                subset = endBatch - initBatch

            tMatrix = torch.eye(2, 3, device=cuda).repeat(subset, 1, 1)

            # Assignment batches use a short, fixed number of iterations
            n_iterations = 3 if mode == "align_classes" else niter

            for it in range(n_iterations):
                if mode == "create_classes":
                    print(f"[{it + 1}/{n_iterations}] Updating classes...", flush=True)

                matches = torch.full((subset, 5), float("Inf"), device=cuda)

                vectorRot, vectorshift = bnb.determine_ROTandSHIFT(it, mode, dim)
                nShift = len(vectorshift)

                for rot in vectorRot:

                    # Precompute the projections of the reference images
                    batch_projRef = bnb.precalculate_projection(cl, freqBn, grid_flat,
                                                                coef, cvecs, float(rot), vectorshift)

                    steps = initStep if mode == "create_classes" else 1

                    # A dedicated index avoids clobbering the outer batch index
                    for step in range(steps):

                        if mode == "create_classes":
                            init = step * expBatchSize
                            batch_projExp = batch_projExp_cpu[step].to('cuda', non_blocking=True)
                        else:
                            init = 0
                            batch_projExp = batch_projExp_cpu

                        matches = bnb.match_batch(batch_projExp, batch_projRef, init, matches, rot, nShift)
                        del batch_projExp
                    del batch_projRef

                if mode == "create_classes" and it in res_map:
                    # Rebuild the PCA basis at the next resolution step
                    del freqBn, coef, grid_flat, cvecs

                    maxRes = max(res_map[it], highRes)

                    freqBn, cvecs, coef = pca.calculatePCAbasis(
                        mmap, Ntrain, nBand, dim, sampling, maxRes,
                        minRes=530, per_eig=per_eig_value, batchPCA=True
                    )
                    grid_flat = flatGrid(freqBn, nBand)

                # update classes
                classes = len(cl)

                if mode == "create_classes":
                    cl, tMatrix, batch_projExp_cpu = bnb.create_classes(
                        mmap, tMatrix, it, subset, expBatchSize, matches, vectorshift,
                        classes, final_classes, freqBn, coef, cvecs, mask, sigma, sampling, cycles)
                else:
                    torch.cuda.empty_cache()
                    cl, tMatrix, batch_projExp_cpu = bnb.align_particles_to_classes(
                        expImages, cl, tMatrix, it, subset, matches, vectorshift, classes,
                        freqBn, coef, cvecs, mask, sigma, sampling)

                # Results are stored only after the last iteration of a batch
                if it != n_iterations - 1:
                    continue
                if mode == "create_classes":
                    if cycles != num_cycles - 1:
                        continue
                    rows = slice(0, endBatch)
                else:
                    rows = slice(initBatch, endBatch)

                refClas[rows] = matches[:, 1]
                dist[rows] = matches[:, 2].cpu()

                # extract final angular and shift transformations
                batch_angles, batch_shifts = _centered_pose(tMatrix, dim)
                translation_vector[rows] = batch_shifts
                angles_deg[rows] = batch_angles

            del expImages

    counts = torch.bincount(refClas.to(torch.int64), minlength=classes)

    # save classes
    file_final = output + ".mrcs"
    save_images(cl.cpu().detach().numpy(), sampling, file_final)

    assess = evaluation()
    assess.updateExpStar(expStar, refClas, -angles_deg, translation_vector, output, dist)
    assess.createClassesStar(classes, file_final, counts, output)

    # The memory-mapped input is no longer needed
    mmap.close()
383+
384+
385+

0 commit comments

Comments
 (0)