|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""/*************************************************************************** |
| 3 | + * |
| 4 | + * Authors: Erney Ramirez-Aportela |
| 5 | + * |
| 6 | + ***************************************************************************/ |
| 7 | +""" |
| 8 | + |
| 9 | +import mrcfile |
| 10 | +import argparse |
| 11 | +import sys, os |
| 12 | +import numpy as np |
| 13 | +import torch |
| 14 | +from xmippPyModules.classifyPcaFuntion.bnb_gpu import BnBgpu |
| 15 | +from xmippPyModules.classifyPcaFuntion.pca_gpu import PCAgpu |
| 16 | +from xmippPyModules.classifyPcaFuntion.assessment import evaluation |
| 17 | + |
| 18 | + |
| 19 | +def read_images(mrcfilename): |
| 20 | + |
| 21 | + with mrcfile.open(mrcfilename, permissive=True) as f: |
| 22 | + emImages = f.data.astype(np.float32).copy() |
| 23 | + return emImages |
| 24 | + |
| 25 | + |
| 26 | +def save_images(data, voxel, outfilename): |
| 27 | + data = data.astype('float32') |
| 28 | + |
| 29 | + if data.ndim == 2: |
| 30 | + data = np.expand_dims(data, axis=0) |
| 31 | + |
| 32 | + with mrcfile.new(outfilename, overwrite=True) as mrc: |
| 33 | + mrc.set_data(data) |
| 34 | + mrc.voxel_size = (voxel, voxel, 1) |
| 35 | + mrc.update_header_stats() |
| 36 | + |
| 37 | + |
| 38 | +def flatGrid(freq_band, nBand): |
| 39 | + |
| 40 | + dim, _ = freq_band.shape |
| 41 | + |
| 42 | + fx = torch.fft.rfftfreq(dim, d=0.5/np.pi, device=cuda) |
| 43 | + fy = torch.fft.fftfreq(dim, d=0.5/np.pi, device=cuda) |
| 44 | + |
| 45 | + grid_x, grid_y = torch.meshgrid(fx, fy, indexing='xy') |
| 46 | + del fx, fy |
| 47 | + |
| 48 | + grid_flat = [] |
| 49 | + |
| 50 | + for n in range(nBand): |
| 51 | + mask = (freq_band == n) |
| 52 | + |
| 53 | + fx_n = grid_x[mask] |
| 54 | + fy_n = grid_y[mask] |
| 55 | + |
| 56 | + grid_flat.append(torch.stack([fx_n, fy_n], dim=0)) |
| 57 | + del mask, fx_n, fy_n |
| 58 | + |
| 59 | + del grid_x, grid_y |
| 60 | + |
| 61 | + return grid_flat |
| 62 | + |
| 63 | + |
| 64 | +if __name__=="__main__": |
| 65 | + |
| 66 | + parser = argparse.ArgumentParser(description="align images") |
| 67 | + parser.add_argument("-i", "--exp", help="input mrc file for experimental images)", required=True) |
| 68 | + parser.add_argument("-s", "--sampling", help="pixel size of the images", required=True) |
| 69 | + parser.add_argument("-c", "--classes", help="number of 2D classes", required=True) |
| 70 | + parser.add_argument("-r", "--ref", help="2D classes of external method") |
| 71 | + parser.add_argument("--mask", action="store_true", help="A Gaussian mask is used.") |
| 72 | + parser.add_argument("--sigma", type=float, help="value of sigma for the Gaussian mask. " |
| 73 | + "It is only used if the --mask option is applied.") |
| 74 | + parser.add_argument("-o", "--output", help="Root directory for the output files", required=True) |
| 75 | + parser.add_argument("-stExp", "--starExp", help="star file for images") |
| 76 | + #For training |
| 77 | + parser.add_argument("-t", "--training", help="number of image for training", required=True) |
| 78 | + parser.add_argument("-hr", "--highres", help="highest resolution to consider", required=True) |
| 79 | + parser.add_argument("-p", "--perc", help="PCA percentage (between 0-1)", required=True) |
| 80 | + |
| 81 | + |
| 82 | + args = parser.parse_args() |
| 83 | + |
| 84 | + expFile = args.exp |
| 85 | + sampling = float(args.sampling) |
| 86 | + final_classes = int(args.classes) |
| 87 | + refImages = args.ref |
| 88 | + niter = 18 |
| 89 | + mask = args.mask |
| 90 | + sigma = args.sigma |
| 91 | + output = args.output |
| 92 | + expStar = args.starExp |
| 93 | + Ntrain = int(args.training) |
| 94 | + highRes = float(args.highres) |
| 95 | + per_eig_value = float(args.perc) |
| 96 | + |
| 97 | + torch.cuda.is_available() |
| 98 | + torch.cuda.current_device() |
| 99 | + cuda = torch.device('cuda:0') |
| 100 | + |
| 101 | + #Determining GPU free memory |
| 102 | + gpu = torch.cuda.get_device_properties(0) |
| 103 | + total_memory = gpu.total_memory |
| 104 | + allocated_memory = torch.cuda.memory_allocated(0) |
| 105 | + free_memory = (total_memory - allocated_memory) / (1024 ** 3) # free memory GB |
| 106 | + print("Free memory %s" %free_memory) |
| 107 | + |
| 108 | + #Read Images |
| 109 | + mmap = mrcfile.mmap(expFile, permissive=True) |
| 110 | + nExp = mmap.data.shape[0] |
| 111 | + dim = mmap.data.shape[1] |
| 112 | + |
| 113 | + if mask and (sigma is None): |
| 114 | + sigma = dim/3 |
| 115 | + |
| 116 | + refClas = torch.zeros(nExp) |
| 117 | + dist = torch.zeros(nExp) |
| 118 | + translation_vector = torch.zeros(nExp, 2) |
| 119 | + angles_deg = np.zeros(nExp) |
| 120 | + |
| 121 | + #PCA function |
| 122 | + nBand = 1 |
| 123 | + pca = PCAgpu(nBand) |
| 124 | + |
| 125 | + if refImages: |
| 126 | + maxRes = highRes |
| 127 | + else: |
| 128 | + maxRes = 16.0 |
| 129 | + |
| 130 | + freqBn, cvecs, coef = pca.calculatePCAbasis(mmap, Ntrain, nBand, dim, sampling, maxRes, |
| 131 | + minRes=530, per_eig=per_eig_value, batchPCA=True) |
| 132 | + |
| 133 | + grid_flat = flatGrid(freqBn, nBand) |
| 134 | + |
| 135 | + bnb = BnBgpu(nBand) |
| 136 | + |
| 137 | + expBatchSize, expBatchSize2, numFirstBatch, initClBatch = bnb.determine_batches(free_memory, dim) |
| 138 | + print("batches: %s, %s, %s, %s" %(expBatchSize, expBatchSize2, numFirstBatch, initClBatch)) |
| 139 | + |
| 140 | + |
| 141 | + #Initial classes with kmeans |
| 142 | + if refImages: |
| 143 | + initStep = -1 |
| 144 | + clIm = read_images(refImages) |
| 145 | + cl = torch.from_numpy(clIm).float().to(cuda) |
| 146 | + else: |
| 147 | + initStep = int(min(numFirstBatch, np.ceil(nExp/expBatchSize))) |
| 148 | + |
| 149 | + num_clusters_total = final_classes * 60 // 100 |
| 150 | + max_classes_per_round = 50 |
| 151 | + |
| 152 | + clusters_per_round = [] |
| 153 | + remaining = num_clusters_total |
| 154 | + |
| 155 | + while remaining > 0: |
| 156 | + n = min(remaining, max_classes_per_round) |
| 157 | + clusters_per_round.append(n) |
| 158 | + remaining -= n |
| 159 | + |
| 160 | + num_rounds = len(clusters_per_round) |
| 161 | + |
| 162 | + block_size = (nExp + num_rounds - 1) // num_rounds |
| 163 | + all_averages = [] |
| 164 | + |
| 165 | + for r, k_round in enumerate(clusters_per_round): |
| 166 | + |
| 167 | + if k_round == 0: |
| 168 | + continue |
| 169 | + |
| 170 | + start = r * block_size |
| 171 | + end = min(start + block_size, nExp) |
| 172 | + block_len = end - start |
| 173 | + if block_len <= 0: |
| 174 | + break |
| 175 | + |
| 176 | + batch_size = min(initClBatch, block_len) |
| 177 | + |
| 178 | + indices = np.random.choice(block_len, size=batch_size, replace=False) + start |
| 179 | + |
| 180 | + Im_zero = mmap.data[indices].astype(np.float32) |
| 181 | + Texp_zero = torch.as_tensor(Im_zero, device=cuda) |
| 182 | + Texp_zero *= bnb.create_circular_mask(Texp_zero) |
| 183 | + |
| 184 | + pca_zero = bnb.create_batchExp(Texp_zero, freqBn, coef, cvecs) |
| 185 | + |
| 186 | + cl_round, _ = bnb.kmeans_pytorch_for_averages( |
| 187 | + Texp_zero, pca_zero[0], cvecs, num_clusters=k_round |
| 188 | + ) |
| 189 | + |
| 190 | + all_averages.append(cl_round) |
| 191 | + |
| 192 | + del Im_zero, Texp_zero, pca_zero, cl_round |
| 193 | + |
| 194 | + cl = torch.cat(all_averages, dim=0) |
| 195 | + del all_averages |
| 196 | + |
| 197 | + # file_cero = output+"_0.mrcs" |
| 198 | + # save_images(cl.cpu().detach().numpy(), sampling, file_cero) |
| 199 | + |
| 200 | + |
| 201 | + if refImages: |
| 202 | + num_batches = int(np.ceil(nExp / expBatchSize2)) |
| 203 | + else: |
| 204 | + num_batches = min(int(np.ceil(nExp / expBatchSize)), |
| 205 | + int(numFirstBatch + np.ceil( (nExp - (numFirstBatch * expBatchSize))/(expBatchSize2) ))) |
| 206 | + |
| 207 | + |
| 208 | + ### Start initial cycles |
| 209 | + num_cycles = 1 |
| 210 | + for cycles in range (num_cycles): |
| 211 | + batch_projExp_cpu = [] |
| 212 | + endBatch = 0 |
| 213 | + if cycles < num_cycles-1: |
| 214 | + num_batches_in_iter = initStep |
| 215 | + else: |
| 216 | + num_batches_in_iter = num_batches |
| 217 | + |
| 218 | + ### Start iterations per batches |
| 219 | + for i in range(num_batches_in_iter): |
| 220 | + |
| 221 | + mode = False |
| 222 | + |
| 223 | + if i < initStep: |
| 224 | + initBatch = i * expBatchSize |
| 225 | + endBatch = min( (i+1) * expBatchSize, nExp) |
| 226 | + else: |
| 227 | + initBatch = endBatch |
| 228 | + endBatch = min( endBatch + expBatchSize2, nExp) |
| 229 | + |
| 230 | + expImages = mmap.data[initBatch:endBatch].astype(np.float32) |
| 231 | + Texp = torch.from_numpy(expImages).float().to(cuda) |
| 232 | + |
| 233 | + if i < initStep: |
| 234 | + batch_projExp_cpu.append( bnb.batchExpToCpu(Texp, freqBn, coef, cvecs) ) |
| 235 | + if i == initStep-1: |
| 236 | + mode = "create_classes" |
| 237 | + print(f"\nClassification mode", flush=True) |
| 238 | + print(f"Processing batch 0 - {endBatch}\n", flush=True) |
| 239 | + else: |
| 240 | + batch_projExp_cpu = bnb.create_batchExp(Texp, freqBn, coef, cvecs) |
| 241 | + mode = "align_classes" |
| 242 | + if i == initStep: |
| 243 | + print(f"\nAssignment mode", flush=True) |
| 244 | + print(f"Processing batch {initBatch} - {endBatch}", flush=True) |
| 245 | + del(Texp) |
| 246 | + |
| 247 | + if mode: |
| 248 | + |
| 249 | + #Initialization Transformation Matrix |
| 250 | + if mode == "create_classes": |
| 251 | + subset = endBatch |
| 252 | + else: |
| 253 | + subset = endBatch - initBatch |
| 254 | + |
| 255 | + |
| 256 | + tMatrix = torch.eye(2, 3, device = cuda).repeat(subset, 1, 1) |
| 257 | + |
| 258 | + if mode == "align_classes": |
| 259 | + niter = 3 |
| 260 | + |
| 261 | + for iter in range(niter): |
| 262 | + if mode == "create_classes": |
| 263 | + print(f"[{iter + 1}/{niter}] Updating classes...", flush=True) |
| 264 | + |
| 265 | + matches = torch.full((subset, 5), float("Inf"), device = cuda) |
| 266 | + |
| 267 | + vectorRot, vectorshift = bnb.determine_ROTandSHIFT(iter, mode, dim) |
| 268 | + nShift = len(vectorshift) |
| 269 | + |
| 270 | + for rot in vectorRot: |
| 271 | + |
| 272 | + # print("---Precomputing the projections of the reference images---") |
| 273 | + batch_projRef = bnb.precalculate_projection(cl, freqBn, grid_flat, |
| 274 | + coef, cvecs, float(rot), vectorshift) |
| 275 | + |
| 276 | + count = 0 |
| 277 | + steps = initStep if mode == "create_classes" else 1 |
| 278 | + |
| 279 | + for i in range(steps): |
| 280 | + |
| 281 | + if mode == "create_classes": |
| 282 | + init = i*expBatchSize |
| 283 | + batch_projExp = batch_projExp_cpu[count].to('cuda', non_blocking=True) |
| 284 | + else: |
| 285 | + init = 0 |
| 286 | + batch_projExp = batch_projExp_cpu |
| 287 | + |
| 288 | + matches = bnb.match_batch(batch_projExp, batch_projRef, init, matches, rot, nShift) |
| 289 | + del(batch_projExp) |
| 290 | + count+=1 |
| 291 | + del(batch_projRef) |
| 292 | + |
| 293 | + if mode == "create_classes": |
| 294 | + res_map = {4: 14, 7: 12, 10: 10, 13: highRes} |
| 295 | + |
| 296 | + if iter in res_map: |
| 297 | + del (freqBn, coef, grid_flat, cvecs) |
| 298 | + |
| 299 | + maxRes = max(res_map[iter], highRes) |
| 300 | + |
| 301 | + freqBn, cvecs, coef = pca.calculatePCAbasis( |
| 302 | + mmap, Ntrain, nBand, dim, sampling, maxRes, |
| 303 | + minRes=530, per_eig=per_eig_value, batchPCA=True |
| 304 | + ) |
| 305 | + grid_flat = flatGrid(freqBn, nBand) |
| 306 | + |
| 307 | + #update classes |
| 308 | + classes = len(cl) |
| 309 | + |
| 310 | + if mode == "create_classes": |
| 311 | + cl, tMatrix, batch_projExp_cpu = bnb.create_classes( |
| 312 | + mmap, tMatrix, iter, subset, expBatchSize, matches, vectorshift, |
| 313 | + classes, final_classes, freqBn, coef, cvecs, mask, sigma, sampling, cycles) |
| 314 | + |
| 315 | + else: |
| 316 | + torch.cuda.empty_cache() |
| 317 | + cl, tMatrix, batch_projExp_cpu = bnb.align_particles_to_classes(expImages, |
| 318 | + cl, tMatrix, iter, subset, matches, vectorshift, classes, |
| 319 | + freqBn, coef, cvecs, mask, sigma, sampling) |
| 320 | + |
| 321 | + |
| 322 | + # save classes |
| 323 | + # file = output+"_%s_%s_%s.mrcs"%(initBatch,iter+1,cycles) |
| 324 | + # save_images(cl.cpu().detach().numpy(), sampling, file) |
| 325 | + |
| 326 | + |
| 327 | + if cycles == num_cycles-1 and mode == "create_classes" and iter == niter-1: |
| 328 | + |
| 329 | + refClas[:endBatch] = matches[:, 1] |
| 330 | + dist[:endBatch] = matches[:, 2].cpu() |
| 331 | + |
| 332 | + #Applying TMT(inv). |
| 333 | + #This is done because the rotation is performed from the center of the image. |
| 334 | + initial_shift = torch.tensor([[1.0, 0.0, -dim/2], |
| 335 | + [0.0, 1.0, -dim/2], |
| 336 | + [0.0, 0.0, 1.0]], device = tMatrix.device) |
| 337 | + initial_shift = initial_shift.unsqueeze(0).expand(tMatrix.size(0), -1, -1) |
| 338 | + |
| 339 | + tMatrix = torch.cat((tMatrix, torch.zeros((tMatrix.size(0), 1, 3), device=tMatrix.device)), dim=1) |
| 340 | + tMatrix[:, 2, 2] = 1.0 |
| 341 | + tMatrix = torch.matmul(initial_shift, tMatrix) |
| 342 | + tMatrix = torch.matmul(tMatrix, torch.inverse(initial_shift)) |
| 343 | + |
| 344 | + #extract final angular and shift transformations |
| 345 | + rotation_matrix = tMatrix[:, :2, :2] |
| 346 | + translation_vector[:endBatch] = tMatrix[:, :2, 2] |
| 347 | + angles_rad = torch.atan2(rotation_matrix[:, 1, 0], rotation_matrix[:, 0, 0]) |
| 348 | + angles_deg[:endBatch] = np.degrees(angles_rad.cpu().numpy()) |
| 349 | + |
| 350 | + elif mode == "align_classes" and iter == 2: |
| 351 | + |
| 352 | + refClas[initBatch:endBatch] = matches[:, 1] |
| 353 | + dist[initBatch:endBatch] = matches[:, 2].cpu() |
| 354 | + |
| 355 | + initial_shift = torch.tensor([[1.0, 0.0, -dim/2], |
| 356 | + [0.0, 1.0, -dim/2], |
| 357 | + [0.0, 0.0, 1.0]], device = tMatrix.device) |
| 358 | + initial_shift = initial_shift.unsqueeze(0).expand(tMatrix.size(0), -1, -1) |
| 359 | + |
| 360 | + tMatrix = torch.cat((tMatrix, torch.zeros((tMatrix.size(0), 1, 3), device=tMatrix.device)), dim=1) |
| 361 | + tMatrix[:, 2, 2] = 1.0 |
| 362 | + tMatrix = torch.matmul(initial_shift, tMatrix) |
| 363 | + tMatrix = torch.matmul(tMatrix, torch.inverse(initial_shift)) |
| 364 | + |
| 365 | + rotation_matrix = tMatrix[:, :2, :2] |
| 366 | + translation_vector[initBatch:endBatch] = tMatrix[:, :2, 2] |
| 367 | + angles_rad = torch.atan2(rotation_matrix[:, 1, 0], rotation_matrix[:, 0, 0]) |
| 368 | + angles_deg[initBatch:endBatch] = np.degrees(angles_rad.cpu().numpy()) |
| 369 | + del(expImages) |
| 370 | + |
| 371 | + |
| 372 | + counts = torch.bincount(refClas.to(torch.int64), minlength=classes) |
| 373 | + |
| 374 | + #save classes |
| 375 | + file_final = output+".mrcs" |
| 376 | + save_images(cl.cpu().detach().numpy(), sampling, file_final) |
| 377 | + |
| 378 | + # print(counts.int()) |
| 379 | + |
| 380 | + assess = evaluation() |
| 381 | + assess.updateExpStar(expStar, refClas, -angles_deg, translation_vector, output, dist) |
| 382 | + assess.createClassesStar(classes, file_final, counts, output) |
| 383 | + |
| 384 | + |
| 385 | + |
0 commit comments