Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dannce/engine/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,3 +1476,4 @@ def write_npy(uri, gen):
print(fname)
np.save(os.path.join(imdir, fname + ".npy"), bch[0][0][j].astype("uint8"))
np.save(os.path.join(griddir, fname + ".npy"), bch[0][1][j])

87 changes: 85 additions & 2 deletions dannce/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,11 +740,33 @@ def dannce_train(params: Dict):

if params["use_npy"]:
# Add all npy volume directories to list, to be used by generator
npydir = {}
dirnames = ["image_volumes", "grid_volumes", "targets"]
npydir, missing_npydir = {}, {}
for e in range(num_experiments):
npydir[e] = params["experiment"][e]["npy_vol_dir"]

if not os.path.exists(npydir[e]):
missing_npydir[e] = npydir[e]
for dir in dirnames:
os.makedirs(os.path.join(npydir[e], dir))
else:
for dir in dirnames:
dirpath = os.path.join(npydir[e], dir)
if (not os.path.exists(dirpath)) or (len(os.listdir(dirpath)) == 0):
missing_npydir[e] = npydir[e]
os.makedirs(dirpath, exist_ok=True)

# samples = processing.remove_samples_npy(npydir, samples, params)
missing_samples = np.array([samp for samp in samples if int(samp.split("_")[0]) in list(missing_npydir.keys())])
if len(missing_samples) != 0:
print("{} npy files for experiments {} are missing.".format(len(missing_samples), list(missing_npydir.keys())))

vids = {}
for e in range(num_experiments):
vids = processing.initialize_vids(params, datadict, e, vids, pathonly=True)
else:
print("No missing npy files. Ready for training.")

samples = processing.remove_samples_npy(npydir, samples, params)
else:
# Initialize video objects
vids = {}
Expand Down Expand Up @@ -778,6 +800,67 @@ def dannce_train(params: Dict):
# mono conversion will happen from RGB npy files, and the generator
# needs to b aware that the npy files contain RGB content
params["chan_num"] = params["n_channels_in"]

if len(missing_samples) != 0:
valid_params = {
"dim_in": (
params["crop_height"][1] - params["crop_height"][0],
params["crop_width"][1] - params["crop_width"][0],
),
"n_channels_in": params["n_channels_in"],
"batch_size": 1,
"n_channels_out": params["new_n_channels_out"],
"out_scale": params["sigma"],
"crop_width": params["crop_width"],
"crop_height": params["crop_height"],
"vmin": params["vmin"],
"vmax": params["vmax"],
"nvox": params["nvox"],
"interp": params["interp"],
"depth": params["depth"],
"channel_combo": None,
"mode": "coordinates",
"camnames": camnames,
"immode": params["immode"],
"shuffle": False,
"rotation": False,
"vidreaders": vids,
"distort": True,
"expval": True,
"crop_im": False,
"chunks": total_chunks,
"mono": params["mono"],
"mirror": params["mirror"],
"predict_flag": False,
"norm_im": False
}

tifdirs = []
npy_generator = generator.DataGenerator_3Dconv_torch(
missing_samples,
datadict,
datadict_3d,
cameras,
missing_samples,
com3d_dict,
tifdirs,
**valid_params
)
print("Generating missing npy files ...")
for i, samp in enumerate(missing_samples):
exp = int(samp.split("_")[0])
save_root = missing_npydir[exp]
fname = "0_{}.npy".format(samp.split("_")[1])

rr = npy_generator.__getitem__(i)
print(i, end="\r")
np.save(os.path.join(save_root, "image_volumes", fname), rr[0][0][0].astype("uint8"))
np.save(os.path.join(save_root, "grid_volumes", fname), rr[0][1][0])
np.save(os.path.join(save_root, "targets", fname), rr[1][0])

samples = processing.remove_samples_npy(npydir, samples, params)
print("{} samples ready for npy training.".format(len(samples)))

else:
# Used to initialize arrays for mono, and also in *frommem (the final generator)
params["chan_num"] = 1 if params["mono"] else params["n_channels_in"]
Expand Down