diff --git a/n2v/internals/N2V_DataGenerator.py b/n2v/internals/N2V_DataGenerator.py
index 33ca496..973a5ef 100644
--- a/n2v/internals/N2V_DataGenerator.py
+++ b/n2v/internals/N2V_DataGenerator.py
@@ -10,7 +10,7 @@ class N2V_DataGenerator():
     The 'N2V_DataGenerator' enables training and validation data generation for Noise2Void.
     """
 
-    def load_imgs(self, files, dims='YX'):
+    def load_imgs(self, files, to32bit=True, dims='YX'):
         """
         Helper to read a list of files. The images are not required to have same size, but have to be of same
         dimensionality.
@@ -21,7 +21,8 @@ def load_imgs(self, files, dims='YX'):
             List of paths to tiff-files.
         dims : String, optional(default='YX')
             Dimensions of the images to read. Known dimensions are: 'TZYXC'
-
+        to32bit : bool, optional(default=True), convert images to 32-bit float; if False the original datatype is kept
+
         Returns
         -------
         images : list(array(float))
@@ -65,7 +66,10 @@ def load_imgs(self, files, dims='YX'):
             else:
                 _raise("Filetype '{}' is not supported.".format(f))
 
-            img = imread(f).astype(np.float32)
+            if to32bit:
+                img = imread(f).astype(np.float32)
+            else:
+                img = imread(f)
             assert len(img.shape) == len(dims), "Number of image dimensions doesn't match 'dims'."
 
             img = np.moveaxis(img, move_axis_from, move_axis_to)
@@ -80,7 +84,7 @@ def load_imgs(self, files, dims='YX'):
 
         return imgs
 
-    def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX'):
+    def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX', names_back=False, to32bit=True):
         """
         Helper to read all files which match 'filter' from a directory. The images are not required to have same size,
         but have to be of same dimensionality.
@@ -94,15 +98,23 @@ def load_imgs_from_directory(self, directory, filter='*.tif', dims='YX'):
         dims : String, optional(default='YX')
             Dimensions of the images to read. Known dimensions are: 'TZYXC'
 
+        names_back : bool, optional(default=False), if set to True the sorted list of input file names is returned as well
+
+        to32bit : bool, optional(default=True), convert images to 32-bit float; if False the original datatype is kept
+
         Returns
         -------
         images : list(array(float))
             A list of the read tif-files. The images have dimensionality 'SZYXC' or 'SYXC'
+        files : list(String), the input file names; returned as first element if 'names_back' is True
         """
 
         files = glob(join(directory, filter))
         files.sort()
-        return self.load_imgs(files, dims=dims)
+        if names_back:
+            return files, self.load_imgs(files, to32bit, dims=dims)
+        else:
+            return self.load_imgs(files, to32bit, dims=dims)
 
     def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256, 256), augment=True, shuffle=False):
         """
@@ -142,7 +154,7 @@ def generate_patches_from_list(self, data, num_patches_per_img=None, shape=(256
 
         return patches
 
-    def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=True):
+    def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=True, shuffle_patches=True):
         """
         Extracts patches from 'data'. The patches can be augmented, which means they get rotated three times in
         XY-Plane and flipped along the X-Axis. Augmentation leads to an eight-fold increase in training data.
@@ -174,17 +186,18 @@ def generate_patches(self, data, num_patches=None, shape=(256, 256), augment=Tru
             if augment:
                 print("XY-Plane is not square. Omit augmentation!")
 
-        np.random.shuffle(patches)
-        print('Generated patches:', patches.shape)
+        if shuffle_patches:
+            np.random.shuffle(patches)
+        # print('Generated patches:', patches.shape)
         return patches
 
     def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2):
         if num_patches == None:
             patches = []
             if n_dims == 2:
-                if data.shape[1] > shape[0] and data.shape[2] > shape[1]:
-                    for y in range(0, data.shape[1] - shape[0], shape[0]):
-                        for x in range(0, data.shape[2] - shape[1], shape[1]):
+                if data.shape[1] >= shape[0] and data.shape[2] >= shape[1]:
+                    for y in range(0, data.shape[1] - shape[0] + 1, shape[0]):
+                        for x in range(0, data.shape[2] - shape[1] + 1, shape[1]):
                             patches.append(data[:, y:y + shape[0], x:x + shape[1]])
 
                     return np.concatenate(patches)
@@ -193,10 +206,10 @@ def __extract_patches__(self, data, num_patches=None, shape=(256, 256), n_dims=2
                 else:
                     print("'shape' is too big.")
             elif n_dims == 3:
-                if data.shape[1] > shape[0] and data.shape[2] > shape[1] and data.shape[3] > shape[2]:
-                    for z in range(0, data.shape[1] - shape[0], shape[0]):
-                        for y in range(0, data.shape[2] - shape[1], shape[1]):
-                            for x in range(0, data.shape[3] - shape[2], shape[2]):
+                if data.shape[1] >= shape[0] and data.shape[2] >= shape[1] and data.shape[3] >= shape[2]:
+                    for z in range(0, data.shape[1] - shape[0] + 1, shape[0]):
+                        for y in range(0, data.shape[2] - shape[1] + 1, shape[1]):
+                            for x in range(0, data.shape[3] - shape[2] + 1, shape[2]):
                                 patches.append(data[:, z:z + shape[0], y:y + shape[1], x:x + shape[2]])
 
                     return np.concatenate(patches)