diff --git a/.gitignore b/.gitignore index 377e104..3daa629 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__/ /cache/ /data/ *.pt -*.vtk \ No newline at end of file +*.vtk +.DS_Store \ No newline at end of file diff --git a/deprecated/RL_env.py b/deprecated/RL_env.py deleted file mode 100644 index 023314d..0000000 --- a/deprecated/RL_env.py +++ /dev/null @@ -1,71 +0,0 @@ -# the framework to follow -import gym -from gym import spaces -from gym.spaces import Discrete, Box -import numpy as np -from dfibert.tracker.nn. import rl -from dfibert.data import HCPDataContainer, ISMRMDataContainer -hcp_data = HCPDataContainer(100307) -ismrm_data = ISMRMDataContainer() - - -class RLenv(gym.Env): - # justifying action and observation space - def __init__(self, device): - self.device = device - # 30 directions we can take - self.action_space = spaces.Discrete(30) - self.observation_space = spaces.Box(np.array([0,0,0]), np.array([x,y,z])) - self.state = np.random.rand(0,3,size=3) - self.done = Flase - - # tranform the cartesian coordinate state to the interpolated_dwi as the input of network - def get_input_state(self.state): - ras_points = hcp_data.to_ras(self.state) # Transform state to World RAS+ coordinate system - interpolated_dwi = hcp_data.get_interpolated_dwi(ras_points, ignore_outside_points=False) - input_state = interpolated_dwi - - # calculate the angle between action vector and DTI_direction_vector - def Angle_calculator(action_vector, DTI_direction_vector): - m=action_vector - n=DTI_direction_vector - l_m=np.sqrt(m.dot(m)) - l_n=np.sqrt(n.dot(n)) - dot_product=x.dot(y) - cos_=dot_product/(l_m*l_n) - angle=np.arccos(cos_) - return angle - - # action will be performed and returns calculated state and reward - def step(self, action): - # apply action - state += step_width * norm(action_vector) - - # calculate reward - if angle(action_vector, DTI_direction_vector)in range(0,pi/2): - reward = 1 - elif: angle(action_vector, DTI_direction_vector)in range(pi/2,pi): - reward = -1 - else: angle(action_vector, DTI_direction_vector)in range(pi, 2pi): - reward = -2 - - # check if episode is done - if self.state.is_out: - done = True - else: - done = False - - # return step information - return self.state, reward, done - - # reset the game and returns the observed data from the last episode - def reset(self): - # reset state - self.state = np.random.rand(0,3,size=3) - return self.state - - def close(self): - self.env.close() - - # show or render an episode - def render(self, mode="human") diff --git a/deprecated/agent.py b/deprecated/agent.py deleted file mode 100755 index d17a899..0000000 --- a/deprecated/agent.py +++ /dev/null @@ -1,92 +0,0 @@ -import numpy as np -import random -import gym -import math -from collections import namedtuple -from itertools import count -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F - -class DQN(nn.Module): - def __init__(self, x, y, z): - super(). 
__init__ () - # here i use a very simple network which consist of only 2 fully connected hidden layrs and an output layer - self.fc1 == nn.linear(in_features= x * y * z, out_features=32) - self.fc2 == nn.linear(in_features=32, out_features=64) - self.out == nn.linear(in_features=64, out_features=30) - - # implement a forward pass to the network(all pytorch neural network requir forward function) - def forward(self, t): - t = t.flatten(start_dim = 1) - t = F.relu(self.fc1(t)) - t = F.relu(self.fc2(t)) - t = self.out(t) - - return t -# experience class to create instances of experience by calling namedtuple function(creating tuples with named fields) -Experience = namedtuple( - 'experience', - ('state','action','next_state','reward') -) -# ReplayMemory used to store experiences -class ReplayMemory(): - def __init__(self, capacity, batch_size=4): - self.capacity = capacity - # actually holds the stored experiences - self.memory = [] - # number of experiences will be sampled from ReplayMemory - self.batch_size = batch_size - # keep track of how many experiences we've added to memory initialize to 0 - self.push_count = 0 - - # the function used to store experience - def push(self, experience): - # check the amount of experiences already in memory is less than the memory capacity - if len(self.memory) < (self.capacity): - self.memory.append(experience) - else: - # push new experience onto the front of memory overwritting the oldest experiences - self.memory[self.add_experience % self.capacity] = experience - self.push_count += 1 - - # the function used to sample experience - def get_batch(self, batch_size): - mini_batch = random.sample(self.memory, self.batch_size) - return mini_batch - - # return a boolean value tell us whether we can sample from memory - def can_provide_sample(self, batch_size): - return len(self.memory) >= batch_size - -# category of action, exploration or exploitation -class EpsilonGreedyStrategy(): - def __init__(self, start, end, decay): - # epsilon: exploration rate, initially set to 1 - self.start = start - self.end = end - # as the agent learns more about the env,epsilon will decay by decay rate - self.decay = decay - - # returns the calculated exploration rate - def get_exploration_rate(self, current_step): - return self.end + (self.start - self.end) * \ math.exp(-1. 
* current_step * self.decay) - -class Agent(): - def __init__(self, stratagy, num_actions, device): - self.current_step = 0 - self.stratagy = stratagy - self.num_actions = 30 - self.device = device - - def select_action(self, state, policy_net): - rate = stratagy.get_exploration_rate(self.current_step) - self.current_step += 1 - - if rate > random.random(): - return torch.tensor([action]).to(device) # exploration - - else: - with torch.no_grad(): - return policy_net(state).argmax(dim=1).to(device) # exploitation diff --git a/deprecated/main.py b/deprecated/main.py deleted file mode 100755 index 68c124e..0000000 --- a/deprecated/main.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import random -import gym -from itertools import count -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F - -if __name__ == "__main__": - env = gym.make("RL_env-0") - - # set parameters - num_episodes = 100 - eps_start = 1 - eps_end = 0.01 - eps_decay = 0.001 - learning_rate = 0.01 - # discount factor used in bellman equation - gamma = 0.99 - # how frequently update the target network's weights - target_update = 10 - # capacity of ReplayMemory - memory_size = 10000 - batch_size = 4 - - - # set pytorch device tell pytorch use the GPU if it's available otherwise use the CPU - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # set env using the RLenv class - env = RLenv(device) - # set stratagy to be an instance of the EpsilonGreedyStrategy class - stratagy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay) - agent = Agent(stratagy, num_actions, device, state) - memory = ReplayMemory(memory_size) - # creating policy_net and target_net by creating 2 instances of DQN class - # and pass in the input_state we get from env - # put these networks on our defined device by pytorch's to function - policy_net = DQN(env.input_state).to(device) - target_net = DQN(env.input_state).to(device) - # set the weights of target_net to be the same as those in the policy_net - target_net.load_state_dict(policy_net.state_dict()) - # set target_net into eval mode(not in traning mode) - target_net.eval() - # set optimizer equal to the Adam optimizer which accepts our policy_net parameters - # as those will be optimizing and are defined at learning rate - optimizer = optim.Adam(params=policy_net.parameters(), lr=lr) - - - - episode_durations = [] - # traning loop, iterate each episode - for episode in range(num_episodes): - env.reset() - - # nested for loop that will iterate each time step in each episode - for timestep in count(): - # for each time step, agent select action based on the current state and policy_net - # use policy_net to select action if agent exploit the env rather than exploration - action = agent.select_action(input_state, policy_net) - reward, next_state = env.step(self, action) - # then create an experience and put it onto the memory - memory.push(Experience(state, action, next_state, reward)) - state = next_state - - # check if we can get a sample from memry to train our policy_net - if memory.can_provide_sample(batch_size): - experiences = memory.sample(batch_size) - # extract all states, actions, next_states, rewards from a given experience batch - states, actions, next_states, rewards = extract_tensor(experiences) - - # Q values for the corresponding state action pairs that exteacted from experiences batch - current_q_values = QValues.get_current(policy_net,input_states,actions) - # Maximum Q values for the next_states in the batch 
using the target_net - next_q_values = QValues.get_next(target_net, next_input_states) - # calculate the target Qvalues - target_qvalues = (next_q_values * gamma) + rewards - # calculate the loss between the current QValues and the target QValues using mean square error as loss function - loss = F.mse_loss(current_q_values, target_qvalues.unsqueeze(1)) - optimizer.zero_grad() - # computes the gradient of the loss w.r.t all the weights and biases in the policy_net - loss.backeard() - # updates the weights and biases with the gradient computed above - optimizer.step() - - if env.done: - # if the episode is ended append current timestep to the episode_durations list - # to store how long this episode lasted - episode_durations.append(timestep) - break - - # check if we should update the weights of target network before starting the new episode - if episode % target_update = 0: - target_net.load_state_dict(policy_net.state_dict()) - - # the wole process end once it reach the number of episode - env.close() - -# Q value calculator -class QValues(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - @staticmethod - def get_current(policy_net, states, actions): - return policy_net(states).gather(dim=1, index = actions.unsqueeze(-1)) - - @staticmethod - def get_next(target_net, next_states): - return target_net(next_states).max(dim = 1)[0].detach() diff --git a/dfibert/__init__.py b/dfibert/__init__.py index 5005329..815bc92 100644 --- a/dfibert/__init__.py +++ b/dfibert/__init__.py @@ -1,7 +1,11 @@ +"""The deepFibreTrackin (dfibert) module offers modules and classes to help with +development and evaluation of different machine learning approaches for fibre tracking +of diffusion weighted MRI data. +""" from . import cache from . import config from . import data from . import dataset from . import envs from . import tracker -from . import util \ No newline at end of file +from . import util diff --git a/dfibert/cache/__init__.py b/dfibert/cache.py similarity index 97% rename from dfibert/cache/__init__.py rename to dfibert/cache.py index b11a9e4..7a9df27 100644 --- a/dfibert/cache/__init__.py +++ b/dfibert/cache.py @@ -15,7 +15,6 @@ from dipy.io.streamline import save_vtk_streamlines, load_vtk_streamlines from dfibert.config import Config -from .exceptions import KeyNotCachedError class Cache(): """ @@ -134,11 +133,12 @@ def get(self, key): Raises ------ - KeyNotCachedError + LookupError This exception is thrown if no object is assigned to given key. """ if not self.in_cache(key): - raise KeyNotCachedError(key) + raise LookupError("""The key {} isn't cached (anymore). + Check if key is cached with in_cache(key).""".format(key)) self.objects[key]["last_accessed"] = int(time.time()*1000.0) filename = self.objects[key]["filename"] @@ -195,8 +195,7 @@ def clear(self): """ Clears the whole cache. """ - keys = [key for key in self.objects] - for key in keys: + for key in self.objects: self.remove(key) def save_configuration(self): diff --git a/dfibert/cache/exceptions.py b/dfibert/cache/exceptions.py deleted file mode 100644 index 34c9c83..0000000 --- a/dfibert/cache/exceptions.py +++ /dev/null @@ -1,20 +0,0 @@ -class KeyNotCachedError(Exception): - """ - This error is thrown if the given key cannot be mapped to a data record. - - Attributes - ---------- - key: str - The key the cache was unable to retrieve. - """ - - def __init__(self, key): - """ - Parameters - ---------- - key: str - The key the cache was unable to retrieve. 
- """ - self.key = key - super().__init__("""The key {} isn't cached (anymore). - Check if key is cached with in_cache(key).""") diff --git a/dfibert/config/__init__.py b/dfibert/config.py similarity index 75% rename from dfibert/config/__init__.py rename to dfibert/config.py index 882015d..60edf9b 100644 --- a/dfibert/config/__init__.py +++ b/dfibert/config.py @@ -4,10 +4,10 @@ import configparser import atexit -from .exceptions import PathAlreadySetError +_UNSET = configparser._UNSET -class Config(): +class Config(configparser.ConfigParser): """ The Configuration can be used to retrieve configuration parameters or to set them and their default values. @@ -36,18 +36,12 @@ class Config(): Attributes ---------- - config: ConfigParser - The real config parser this class is based on. - It is not advisable to use it directly, use the wrapper functions instead for - correct behaviour. is_immutable: bool The bool indicating wether a change of the configuration is possible in software or not. path: str The current path of the active cache. """ - - config = None - _UNSET = configparser._UNSET + _config = None _CONFIGURATION_FILE = "config.ini" @classmethod def get_config(cls): @@ -60,9 +54,9 @@ def get_config(cls): Config The current active config. """ - if not cls.config: - cls.config = Config(path=cls._CONFIGURATION_FILE) - return cls.config + if not cls._config: + cls._config = Config(path=cls._CONFIGURATION_FILE) + return cls._config @classmethod def set_path(cls, path): """ @@ -76,11 +70,12 @@ def set_path(cls, path): Raises ------ - PathAlreadySetError + RuntimeError Error is thrown if the configuration is already initialized. """ - if cls.config: - raise PathAlreadySetError(path) from None + if cls._config: + raise RuntimeError("Can't set the config path to {} because the config was already \ + initialised with path {}.".format(path, cls._config.get_path())) cls._CONFIGURATION_FILE = path def __init__(self, path): @@ -92,17 +87,16 @@ def __init__(self, path): path : str The path the config uses to load. """ - self.config = configparser.ConfigParser() - self.config.optionxform = str - self.config.read(path) + super().__init__() + self.optionxform = str + self.read(path) - if (not self.config.has_section("configuration") or - not self.config.has_option("configuration", "immutableConfiguration")): + if not self.has_option("configuration", "immutableConfiguration"): self.set("configuration", "immutableConfiguration", "no") - if not self.config.has_option("configuration", "addDefaults"): + if not self.has_option("configuration", "addDefaults"): self.set("configuration", "addDefaults", "yes") - self.is_immutable = self.config.getboolean("configuration", "immutableConfiguration") + self.is_immutable = self.getboolean("configuration", "immutableConfiguration", fallback="False") self.path = path atexit.register(self.save_configuration) @@ -119,9 +113,8 @@ def _handle_add_default(self, section, option, fallback): fallback : str The fallback value of this option as string, regardless of real type. 
""" - if (fallback is not self._UNSET and (not self.config.has_section(section) - or not self.config.has_option(section, option)) - and self.config.getboolean("configuration", "addDefaults")): + if (fallback is not _UNSET and not self.has_option(section, option) + and self.getboolean("configuration", "addDefaults")): self.set(section, option, fallback) def set(self, section, option, value=None): @@ -138,11 +131,11 @@ def set(self, section, option, value=None): value : str, optional The value of the option, by default None. """ - if not section in self.config: - self.config[section] = {} - self.config.set(section, option, value) + if not section in self: + self[section] = {} + super().set(section, option, value) - def get(self, section, option, fallback=_UNSET): + def get(self, section, option, *args,fallback=_UNSET, **kwargs): """Gets the configuration option with specified parameters. `fallback` has to be a string. @@ -162,8 +155,8 @@ def get(self, section, option, fallback=_UNSET): The option to get """ self._handle_add_default(section, option, fallback) - return self.config.get(section, option, fallback=fallback) - def getint(self, section, option, fallback=_UNSET): + return super().get(section, option, *args, fallback=fallback, **kwargs) + def getint(self, section, option, *args, fallback=_UNSET, **kwargs): """Gets the configuration option with specified parameters as integer. `fallback` has to be a string. @@ -183,8 +176,8 @@ def getint(self, section, option, fallback=_UNSET): The option to retrieve as int. """ self._handle_add_default(section, option, fallback) - return self.config.getint(section, option, fallback=fallback) - def getfloat(self, section, option, fallback=_UNSET): + return super().getint(section, option, *args, fallback=_UNSET, **kwargs) + def getfloat(self, section, option, *args, fallback=_UNSET, **kwargs): """Gets the configuration option with specified parameters as float. `fallback` has to be a string. @@ -204,8 +197,8 @@ def getfloat(self, section, option, fallback=_UNSET): The option to retrieve as float. """ self._handle_add_default(section, option, fallback) - return self.config.getfloat(section, option, fallback=fallback) - def getboolean(self, section, option, fallback=_UNSET): + return super().getfloat(section, option, *args, fallback=_UNSET, **kwargs) + def getboolean(self, section, option, *args, fallback=_UNSET, **kwargs): """Gets the configuration option with specified parameters as boolean. `fallback` has to be a string. @@ -225,7 +218,7 @@ def getboolean(self, section, option, fallback=_UNSET): The option to retrieve as boolean. """ self._handle_add_default(section, option, fallback) - return self.config.getboolean(section, option, fallback=fallback) + return super().getboolean(section, option, *args, fallback=_UNSET, **kwargs) def get_path(self): @@ -246,4 +239,4 @@ def save_configuration(self): """ if not self.is_immutable: with open(self.path, 'w') as configfile: - self.config.write(configfile) + self.write(configfile) diff --git a/dfibert/config/exceptions.py b/dfibert/config/exceptions.py deleted file mode 100644 index b00bd23..0000000 --- a/dfibert/config/exceptions.py +++ /dev/null @@ -1,26 +0,0 @@ -# exception classes -class PathAlreadySetError(Exception): - """ - This Exception is thrown if the Config is already initialized - and your code tries to set the path. - - Attributes - ---------- - path : str - The path your code was trying to set. - current_path : str - The actual path the configuration was initialized with. 
- """ - - def __init__(self, path): - """ - Parameters - ---------- - path : str - The path your code was trying to set. - """ - - self.path = path - self.current_path = Config.get_config().get_path() - super().__init__("""Path of config file already set to \"{}\". - Setting it to \"{}\" failed.""".format(self.current_path, path)) \ No newline at end of file diff --git a/dfibert/data/__init__.py b/dfibert/data/__init__.py index d5b85ab..300861a 100644 --- a/dfibert/data/__init__.py +++ b/dfibert/data/__init__.py @@ -2,9 +2,9 @@ The data module is handling all kinds of DWI-data. Use this as a starting point to represent your loaded DWI-scan. -This module provides methods helping you to implement datasets, +This module provides methods helping you to implement datasets, environments and all other kinds of modules with the requirement -to work directly with the data. +to work directly with the data. """ import os @@ -13,29 +13,31 @@ import torch +import dipy.reconst.dti as dti from dipy.core.gradients import gradient_table from dipy.io import read_bvals_bvecs from dipy.denoise.localpca import localpca from dipy.denoise.pca_noise_estimate import pca_noise_estimate from dipy.align.reslice import reslice from dipy.segment.mask import median_otsu -from scipy.interpolate import RegularGridInterpolator - from dipy.tracking.streamline import interpolate_vector_3d, interpolate_scalar_3d -import dipy.reconst.dti as dti + +from scipy.interpolate import RegularGridInterpolator import numpy as np + import nibabel as nb from nibabel.affines import apply_affine from dfibert.config import Config -from dfibert.data.exceptions import (DeviceNotRetrievableError, DataContainerNotLoadableError, DWIAlreadyCroppedError, - DWIAlreadyNormalizedError, PointOutsideOfDWIError) +from dfibert.data.exceptions import (DeviceNotRetrievableError, DataContainerNotLoadableError, + DWIAlreadyCroppedError, DWIAlreadyNormalizedError, + PointOutsideOfDWIError) class RawData(SimpleNamespace): """ This class represents the raw loaded data, providing attributes to access it. - + You should mainly see it as part of an DataContainer, which provides helpful methods to manipulate it or access (interpolated, processed) values @@ -60,6 +62,11 @@ class RawData(SimpleNamespace): b0: ndarray The b0 image usable for normalization etc. """ + def __init__(self): + self.bvals, self.bvecs, self.img = None, None, None + self.t1, self.gtab, self.dwi = None, None, None + self.aff, self.binarymask, self.b0 = None, None, None + self.fa = None class DataContainer(): """ @@ -85,7 +92,8 @@ class DataContainer(): To inherit the `DataContainer` class, you are advised to use the following function: `_retrieve_data(self, file_names, denoise=False, b0_threshold=None)` - This reads the properties of the given path based on the filenames and denoises the image, if applicable. + This reads the properties of the given path based on the filenames and + denoises the image, if applicable. Then it returns a RawData object. which is automatically called in the constructor. @@ -249,8 +257,9 @@ def get_interpolated_dwi(self, points, postprocessing=None, ignore_outside_point The shape of the input points will be retained for the return array, only the last dimension will be changed from 3 to the (interpolated) DWI-size accordingly. - - If you provide a postprocessing method, the interpolated data is then fed through this postprocessing option. + + If you provide a postprocessing method, the interpolated data is then fed through this + postprocessing option. 
Parameters ---------- @@ -271,24 +280,25 @@ def get_interpolated_dwi(self, points, postprocessing=None, ignore_outside_point shape = points.shape new_shape = (*shape[:-1], -1) points = points.reshape(-1, 3) - - condition = ((points[:, 0] < 0) + (points[:, 0] >= self.data.dwi.shape[0]) + # OR - (points[:, 1] < 0) + (points[:, 1] >= self.data.dwi.shape[1]) + - (points[:, 2] < 0) + (points[:, 2] >= self.data.dwi.shape[2])) - - a, = np.nonzero(condition) # np.nonzero returns tuple (a) - if len(a) > 0 and not ignore_outside_points: - raise PointOutsideOfDWIError(self, self.to_ras(points), self.to_ras(points[a])) - - points[a] = np.zeros(3) # set the points being outside to inside points - - result = self.interpolator(points[not a]) + + condition = ((points[:, 0] < 0) + (points[:, 0] >= self.data.dwi.shape[0]) + # OR + (points[:, 1] < 0) + (points[:, 1] >= self.data.dwi.shape[1]) + + (points[:, 2] < 0) + (points[:, 2] >= self.data.dwi.shape[2])) + + affected_indices, = np.nonzero(condition) # np.nonzero returns tuple (a) + if len(affected_indices) > 0 and not ignore_outside_points: + raise PointOutsideOfDWIError(self, self.to_ras(points), + self.to_ras(points[affected_indices])) + + points[affected_indices] = np.zeros(3) # set the points being outside to inside points + + result = self.interpolator(points[not affected_indices]) if postprocessing is not None: - result = postprocessing(result, self.data.b0, - self.data.bvecs, + result = postprocessing(result, self.data.b0, + self.data.bvecs, self.data.bvals) - result[a, :] = 0 # overwrite their interpolated value + result[affected_indices, :] = 0 # overwrite their interpolated value result = result.reshape(new_shape) return result @@ -297,8 +307,8 @@ def crop(self, b_value=None, max_deviation=None, ignore_already_cropped=False): """Crops the dataset based on B-value. This function crops the DWI-Image based on B-Value. - Pay attention to the fact that every value deviating more than `max_deviation` from the specified `b_value` - will be irretrievably removed in the object. + Pay attention to the fact that every value deviating more than `max_deviation` + from the specified `b_value` will be irretrievably removed in the object. Parameters ---------- @@ -358,7 +368,7 @@ def normalize(self): Raises ------ DWIAlreadyCroppedError - If the DWI is already cropped, normalization doesn't make much sense anymore. + If the DWI is already cropped, normalization doesn't make much sense anymore. Thus this is prevented. Returns @@ -397,7 +407,7 @@ def normalize(self): def generate_fa(self): """Generates the FA Values for DataContainer. - Normalization is required. + Normalization is required. It is recommended to call the routine ahead of cropping, such that the FA values make sense, but it is not prohibited @@ -407,7 +417,7 @@ def generate_fa(self): Fractional anisotropy (FA) calculated from cached eigenvalues. """ if self.options.cropped: - warnings.warn("""You are generating the fa values from already cropped DWI. + warnings.warn("""You are generating the fa values from already cropped DWI. You typically want to generate_fa() before you crop the data.""") dti_model = dti.TensorModel(self.data.gtab, fit_method='LS') dti_fit = dti_model.fit(self.data.dwi) @@ -420,7 +430,7 @@ def get_fa(self): ------- ndarray Fractional anisotropy (FA) calculated from cached eigenvalues. - + See Also -------- generate_fa: The method generating the fa values which are returned here. 
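The hunks above document `DataContainer.get_interpolated_dwi`, `normalize()` and the grid helpers. The short, standalone Python sketch below is illustrative only and not part of the patch; it assumes a locally available HCP subject 100307 (the default ID used elsewhere in this diff), and the seed coordinate is a hypothetical placeholder.

import numpy as np
from dipy.data import get_sphere
from dfibert.data import HCPDataContainer
from dfibert.data.postprocessing import resample
from dfibert.util import get_grid

# Load and normalise a subject; the docstrings above recommend normalising before cropping.
data = HCPDataContainer(100307)
data.normalize()

# Build a 3x3x3 interpolation grid around a (hypothetical) seed point.
grid = get_grid(np.array([3, 3, 3]))
seed_ras = data.to_ras(np.array([45.0, 60.0, 38.0]))  # placeholder seed coordinate
points = grid + seed_ras

# Interpolate the DWI at the grid points and resample onto the repulsion100 sphere.
# As documented above, the input point shape is kept; only the last dimension
# becomes the size of the resampled DWI signal.
dwi = data.get_interpolated_dwi(points,
                                postprocessing=resample(sphere=get_sphere("repulsion100")),
                                ignore_outside_points=True)
print(dwi.shape)  # e.g. (3, 3, 3, 100) for the 100-direction sphere

This mirrors how the environment code later in this diff calls the same API (grid offsets added to a RAS point, then resampled via `resample(sphere=...)`).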
diff --git a/dfibert/data/exceptions.py b/dfibert/data/exceptions.py index b72d38d..1a98816 100644 --- a/dfibert/data/exceptions.py +++ b/dfibert/data/exceptions.py @@ -1,3 +1,4 @@ +"Exceptions for the data submodule" class DeviceNotRetrievableError(Exception): """ Exception thrown if get_device is called on non-CUDA tensor. @@ -102,11 +103,12 @@ def __init__(self, data_container): """Parameters ---------- data_container : DataContainer - The DataContainer which is already normalized. + The DataContainer which is already normalized. """ self.data_container = data_container - super().__init__("The DWI of the DataContainer {id} is already normalized. ".format(id=data_container.id)) + super().__init__("The DWI of the DataContainer {id} is already normalized. " + .format(id=data_container.id)) class PointOutsideOfDWIError(Exception): """ @@ -141,4 +143,4 @@ def __init__(self, data_container, points, affected_points): super().__init__(("While parsing {no_points} points for further processing, " "it became apparent that {aff} of the points " "doesn't lay inside of DataContainer '{id}'.") - .format(no_points=len(points), id=data_container.id, aff=affected_points)) \ No newline at end of file + .format(no_points=len(points), id=data_container.id, aff=affected_points)) diff --git a/dfibert/data/postprocessing.py b/dfibert/data/postprocessing.py index dfbef3d..4b2bc3e 100644 --- a/dfibert/data/postprocessing.py +++ b/dfibert/data/postprocessing.py @@ -10,10 +10,9 @@ import numpy as np from dipy.core.sphere import Sphere from dipy.reconst.shm import real_sym_sh_mrtrix, smooth_pinv -from dipy.data import get_sphere from dfibert.config import Config -from dfibert.util import get_2D_sphere +from dfibert.util import get_2d_sphere, get_sphere_from_param def raw(): """Does no resampling. 
@@ -47,11 +46,13 @@ def spherical_harmonics(sh_order=None, smooth=None): def _wrapper(dwi, _b0, bvecs, _bvals): with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) # TODO - look if bvecs should be normalized, then this can be removed + warnings.filterwarnings("ignore", category=UserWarning) + # TODO - look if bvecs should be normalized, then this can be removed + raw_sphere = Sphere(xyz=bvecs) - real_sh, _, n = real_sym_sh_mrtrix(sh_order, raw_sphere.theta, raw_sphere.phi) - l = -n * (n + 1) + real_sh, _, harmonics_order = real_sym_sh_mrtrix(sh_order, raw_sphere.theta, raw_sphere.phi) + l = -harmonics_order * (harmonics_order + 1) inv_b = smooth_pinv(real_sh, np.sqrt(smooth) * l) data_sh = np.dot(dwi, inv_b.T) @@ -80,17 +81,11 @@ def resample(directions=None, sh_order=None, smooth=None, mean_centering=None, s if sphere is None: sphere = config.get("ResamplingOptions", "sphere", fallback="repulsion100") - if isinstance(sphere, Sphere): - rsphere = sphere - sphere = "custom" - else: - rsphere = get_sphere(sphere) - if directions is not None: - rsphere = Sphere(xyz=directions) - real_sh, _, _ = real_sym_sh_mrtrix(sh_order, rsphere.theta, rsphere.phi) + sphere_name, real_sphere = get_sphere_from_param(sphere, directions) + real_sh, _, _ = real_sym_sh_mrtrix(sh_order, real_sphere.theta, real_sphere.phi) - def _wrapper(dwi, b0, bvecs, bvals): - data_sh = spherical_harmonics(sh_order=sh_order, smooth=smooth)(dwi, b0, bvecs, bvals) + def _wrapper(dwi, b0_vals, bvecs, bvals): + data_sh = spherical_harmonics(sh_order=sh_order, smooth=smooth)(dwi, b0_vals, bvecs, bvals) data_resampled = np.dot(data_sh, real_sh.T) @@ -101,7 +96,7 @@ def _wrapper(dwi, b0, bvecs, bvals): data_resampled[idx] -= means return data_resampled _wrapper.id = ("resample-{sphere}-sh-order-{sh}-smooth-{sm}-mean_centering-{mc}" - .format(sphere=sphere, sh=sh_order, sm=smooth, mc=mean_centering)) + .format(sphere=sphere_name, sh=sh_order, sm=smooth, mc=mean_centering)) return _wrapper def res100(sh_order=None, smooth=None, mean_centering=None): @@ -121,7 +116,7 @@ def res100(sh_order=None, smooth=None, mean_centering=None): return resample(sh_order=sh_order, smooth=smooth, mean_centering=mean_centering, sphere="repulsion100") -def resample2D(sh_order=None, smooth=None, mean_centering=None, no_thetas=None, no_phis=None): +def resample2d(sh_order=None, smooth=None, mean_centering=None, no_thetas=None, no_phis=None): """Resamples the value to directions with the 2D sphere. Just a shortcut for the `resample` option with 2D sphere. @@ -129,7 +124,7 @@ def resample2D(sh_order=None, smooth=None, mean_centering=None, no_thetas=None, See Also -------- resample: the function this is based on. - dfibert.util.get_2D_sphere: the function the 2D sphere is generated with. + dfibert.util.get_2d_sphere: the function the 2D sphere is generated with. Returns ------- @@ -137,7 +132,7 @@ def resample2D(sh_order=None, smooth=None, mean_centering=None, no_thetas=None, A function with `id` attribute, which parses dwi accordingly. 
""" func = resample(sh_order=sh_order, smooth=smooth, mean_centering=mean_centering, - sphere=get_2D_sphere(no_phis=no_phis, no_thetas=no_thetas)) + sphere=get_2d_sphere(no_phis=no_phis, no_thetas=no_thetas)) func.id = ("resample-2Dsphere-{nt}x{np}-sh-order-{sh}-smooth-{sm}-mean_centering-{mc}" .format(nt=no_thetas, np=no_phis, sh=sh_order, sm=smooth, mc=mean_centering)) return func diff --git a/dfibert/dataset/__init__.py b/dfibert/dataset/__init__.py index 84bb80d..5e72926 100644 --- a/dfibert/dataset/__init__.py +++ b/dfibert/dataset/__init__.py @@ -14,10 +14,12 @@ from dfibert.config import Config from dfibert.util import get_reference_orientation, rotation_from_vectors, get_grid +from dfibert.data.exceptions import DeviceNotRetrievableError from .exceptions import WrongDatasetTypePassedError, FeatureShapesNotEqualError + class MovableData(): """ This class can be used to make classes handling multiple tensors more easily movable. @@ -76,7 +78,7 @@ def _get_tensors(self): """ tensors = {} for key, value in vars(self).items(): - if isinstance(value, torch.Tensor) or isinstance(value, MovableData): + if isinstance(value, (torch.Tensor, MovableData)): tensors[key] = value return tensors @@ -206,7 +208,7 @@ class BaseDataset(MovableData): """The base class for Datasets in this library. It extends `MovableData`. - + Attributes ---------- device: torch.device, optional @@ -214,7 +216,8 @@ class BaseDataset(MovableData): data_container: DataContainer The DataContainer the dataset is based on id: str - An ID representing this Dataset. This is not unique to any instance, but it consists of parameters and used dataset. + An ID representing this Dataset. This is not unique to any instance, + but it consists of parameters and used dataset. Methods ------- @@ -250,32 +253,35 @@ def __init__(self, data_container, device=None): class IterableDataset(BaseDataset, torch.utils.data.Dataset): + "A dataset which we are able to iterate over" def __init__(self, data_container, device=None): BaseDataset.__init__(self, data_container, device=device) torch.utils.data.Dataset.__init__(self) def __len__(self): - if type(self) is IterableDataset: - raise NotImplementedError() from None + raise NotImplementedError() from None def __getitem__(self, index): - if type(self) is IterableDataset: - raise NotImplementedError() from None + raise NotImplementedError() from None class SaveableDataset(IterableDataset): + "A dataset which we can save into a file" def __init__(self, data_container, device=None): - IterableDataset.__init__(self,data_container, device=device) - + super().__init__(data_container, device=device) + def _get_variable_elements_data(self): + assert len(self) > 0 lengths = np.zeros(len(self), dtype=int) + in_shape, out_shape = None, None for i, (inp, out) in enumerate(self): assert len(inp) == len(out) lengths[i] = len(inp) - return lengths, inp.shape[1:], out.shape[1:] - - def saveToPath(self, path): + in_shape, out_shape = inp.shape[1:], out.shape[1:] + return lengths, in_shape, out_shape + def save_to_path(self, path): + "Saves the saveable dataset to a path" os.makedirs(path, exist_ok=True) lengths, in_shape, out_shape = self._get_variable_elements_data() print(lengths) @@ -284,25 +290,31 @@ def saveToPath(self, path): in_shape=tuple([data_length] + list(in_shape)) out_shape=tuple([data_length] + list(out_shape)) - inp_memmap = np.memmap(os.path.join(path, 'input.npy'), dtype='float32', shape=in_shape, mode='w+') - out_memmap = np.memmap(os.path.join(path, 'output.npy'), dtype='float32', 
shape=out_shape, mode='w+') - + inp_memmap = np.memmap(os.path.join(path, 'input.npy'), dtype='float32', + shape=in_shape, mode='w+') + out_memmap = np.memmap(os.path.join(path, 'output.npy'), dtype='float32', + shape=out_shape, mode='w+') + idx = 0 - assert (len(self) == len(lengths)) - for i in range(len(self)): - inp,out = self[i] - print(i, ": ", inp.shape, " l " ,lengths[i], " - ",lengths.shape) - assert(len(inp) == lengths[i]) - inp_memmap[idx:(idx + lengths[i])] = inp.numpy() - out_memmap[idx:(idx + lengths[i])] = out.numpy() - idx = idx + lengths[i] + assert len(self) == len(lengths) + for i, ((inp, out), length) in enumerate(zip(self, lengths)): + assert len(inp) == length + inp_memmap[idx:(idx + length)] = inp.numpy() + out_memmap[idx:(idx + length)] = out.numpy() + idx = idx + length print("{}/{}".format(i, len(lengths)), end="\r") np.save(os.path.join(path, 'lengths.npy'), lengths) with open(os.path.join(path, 'info.json'), 'w') as infofile: json.dump({"id": self.id, "input_shape":in_shape, "output_shape":out_shape}, infofile) + def __len__(self): + raise NotImplementedError() from None + + def __getitem__(self, index): + raise NotImplementedError() from None class LoadedDataset(IterableDataset): + "Represents a dataset loaded from a file" def __init__(self, path, device=None, passSingleElements=False): IterableDataset.__init__(self, None, device=device) self.path = path @@ -313,7 +325,8 @@ def __init__(self, path, device=None, passSingleElements=False): inp_shape = tuple(info_data["input_shape"]) out_shape = tuple(info_data["output_shape"]) - self.feature_shapes = np.prod(info_data["input_shape"][1:]), np.prod(info_data["output_shape"][1:]) + self.feature_shapes = (np.prod(info_data["input_shape"][1:]), + np.prod(info_data["output_shape"][1:])) if not passSingleElements: self.sl_lengths = np.load(os.path.join(self.path, 'lengths.npy')) @@ -321,34 +334,42 @@ def __init__(self, path, device=None, passSingleElements=False): self.sl_lengths = np.ones((inp_shape[0])) self.sl_start_indices = np.append(0, np.cumsum(self.sl_lengths)) - self.inp_memmap = np.memmap(os.path.join(self.path, 'input.npy'), dtype='float32', shape=inp_shape, mode='r') - self.out_memmap = np.memmap(os.path.join(self.path, 'output.npy'), dtype='float32', shape=out_shape, mode='r') - + self.inp_memmap = np.memmap(os.path.join(self.path, 'input.npy'), + dtype='float32', shape=inp_shape, mode='r') + self.out_memmap = np.memmap(os.path.join(self.path, 'output.npy'), + dtype='float32', shape=out_shape, mode='r') + def __len__(self): return len(self.sl_lengths) def __getitem__(self, index): - inp = torch.from_numpy(self.inp_memmap[self.sl_start_indices[index]:self.sl_start_indices[index+1]]).to(self.device) - out = torch.from_numpy(self.out_memmap[self.sl_start_indices[index]:self.sl_start_indices[index+1]]).to(self.device) + inp = (torch.from_numpy( + self.inp_memmap[self.sl_start_indices[index]:self.sl_start_indices[index+1]]) + .to(self.device)) + out = (torch.from_numpy( + self.out_memmap[self.sl_start_indices[index]:self.sl_start_indices[index+1]]) + .to(self.device)) return (inp, out) def get_feature_shapes(self): + "Returns the feature shapes of the dataset" return self.feature_shapes class ConcatenatedDataset(SaveableDataset): + "A Concatenated Dataset built from multiple existing datasets" def __init__(self, datasets, device=None): - IterableDataset.__init__(self, None, device=device) + super().__init__(None, device=device) self.id = self.id + "[" self.__lens = [0] - for index, ds in enumerate(datasets): - 
if not isinstance(ds, IterableDataset): - raise WrongDatasetTypePassedError(self, ds, + for index, dataset in enumerate(datasets): + if not isinstance(dataset, IterableDataset): + raise WrongDatasetTypePassedError(self, dataset, ("Dataset {} doesn't inherit IterableDataset. " - "It is {} ").format(index, type(ds)) + "It is {} ").format(index, type(dataset)) ) from None - ds.to(self.device) - self.id = self.id + ds.id + ", " - self.__lens.append(len(ds) + self.__lens[-1]) + dataset.to(self.device) + self.id = self.id + dataset.id + ", " + self.__lens.append(len(dataset) + self.__lens[-1]) self.id = self.id[:-2] + "]" self.datasets = datasets self.options = SimpleNamespace() @@ -366,11 +387,12 @@ def __getitem__(self, index): return self.datasets[i][index - self.__lens[i]] def get_feature_shapes(self): + "Returns the feature shapes of the dataset" # assert that each dataset has same dataset shape (inp, out) = self.datasets[0].get_feature_shapes() for i in range(1, len(self.datasets)): (inp2, out2) = self.datasets[i].get_feature_shapes() - if (not torch.all(torch.tensor(inp).eq(torch.tensor(inp2))) or + if (not torch.all(torch.tensor(inp).eq(torch.tensor(inp2))) or not torch.all(torch.tensor(out).eq(torch.tensor(out2)))): raise FeatureShapesNotEqualError(i, (inp, out), (inp2, out2)) return (inp, out) @@ -395,10 +417,10 @@ def to(self, *args, **kwargs): return self class StreamlineDataset(SaveableDataset): - + "A dataset representing a list of streamlines" def __init__(self, tracker, data_container, processing, device=None, append_reverse=None, online_caching=None): - IterableDataset.__init__(self, data_container, device=device) + super().__init__(data_container, device=device) self.streamlines = tracker.get_streamlines() self.id = self.id + "-{}-(".format(processing.id) + tracker.id + ")" config = Config.get_config() @@ -415,11 +437,11 @@ def __init__(self, tracker, data_container, processing, if online_caching: self.cache = [None] * len(self) self.feature_shapes = None - + def _get_variable_elements_data(self): lengths = np.zeros(len(self) , dtype=int) - for i, sl in enumerate(self.streamlines): - lengths[i] = len(sl) + for i, streamline in enumerate(self.streamlines): + lengths[i] = len(streamline) if self.options.append_reverse: lengths[len(self.streamlines):] = lengths[:len(self.streamlines)] (inp, out) = self[0] @@ -435,14 +457,13 @@ def __getitem__(self, index): if self.options.online_caching and self.cache[index] is not None: return self.cache[index] (inp, output) = self._calculate_item(index) - inp = torch.from_numpy(inp).to(device=self.device, dtype=torch.float32) # TODO work on dtypes + inp = torch.from_numpy(inp).to(device=self.device, dtype=torch.float32) output = torch.from_numpy(output).to(device=self.device, dtype=torch.float32) if self.options.online_caching: self.cache[index] = (inp, output) - return self.cache[index] - else: - return (inp, output) + + return (inp, output) def _calculate_item(self, index): streamline = self._get_streamline(index) @@ -461,6 +482,7 @@ def _get_streamline(self, index): def get_feature_shapes(self): + "Returns the feature shapes of the dataset" if self.feature_shapes is None: dwi, next_dir = self[0] # assert that every type of data processing maintains same shape @@ -475,53 +497,53 @@ def get_feature_shapes(self): def cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format): if not self.options.online_caching: - return + return self dwi = None - for index, el in enumerate(self.cache): - if el is None: + for index, element 
in enumerate(self.cache): + if element is None: continue - dwi, next_dir = el + dwi, next_dir = element dwi = dwi.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format) next_dir = next_dir.cuda(device=device, non_blocking=non_blocking, memory_format=memory_format) self.cache[index] = (dwi, next_dir) if self.device == dwi.device: # move is unnecessary - return + return self if dwi is not None: self.device = dwi.device return self def cpu(self, memory_format=torch.preserve_format): if not self.options.online_caching: - return + return self dwi = None - for index, el in enumerate(self.cache): - if el is None: + for index, element in enumerate(self.cache): + if element is None: continue - dwi, next_dir = el + dwi, next_dir = element dwi = dwi.cpu(memory_format=memory_format) next_dir = next_dir.cpu(memory_format=memory_format) self.cache[index] = (dwi, next_dir) if self.device == dwi.device: # move is unnecessary - return + return self if dwi is not None: self.device = dwi.device return self def to(self, *args, **kwargs): if not self.options.online_caching: - return + return self dwi = None - for index, el in enumerate(self.cache): - if el is None: + for index, element in enumerate(self.cache): + if element is None: continue - dwi, next_dir = el + dwi, next_dir = element dwi = dwi.to(*args, **kwargs) next_dir = next_dir.to(*args, **kwargs) self.cache[index] = (dwi, next_dir) if self.device == dwi.device: # move is unnecessary - return + return self if dwi is not None: self.device = dwi.device return self diff --git a/dfibert/dataset/exceptions.py b/dfibert/dataset/exceptions.py index 5e66eb6..83d703c 100644 --- a/dfibert/dataset/exceptions.py +++ b/dfibert/dataset/exceptions.py @@ -1,3 +1,4 @@ +"Exceptions for the dataset submodule" class WrongDatasetTypePassedError(Exception): """Error thrown if `ConcatenatedDataset` retrieves wrong datasets. @@ -57,4 +58,4 @@ def __init__(self, index, s1, s2): self.index = index super().__init__(("The shape of the dataset {idx} ({s2}) " "is not equal to the base shape of the reference dataset 0 ({s1})" - ).format(idx=index, s2=s2, s1=s1)) \ No newline at end of file + ).format(idx=index, s2=s2, s1=s1)) diff --git a/dfibert/dataset/processing.py b/dfibert/dataset/processing.py index 8022139..81ae36f 100644 --- a/dfibert/dataset/processing.py +++ b/dfibert/dataset/processing.py @@ -5,19 +5,20 @@ Processing The base class for all processing instructions RegressionProcessing - The basic processing, calculates direction vectors out of streamlines and interpolates DWI along a grid + The basic processing, calculates direction vectors out of streamlines + and interpolates DWI along a grid. ClassificationProcessing - Based on RegressionProcessing, however it reshapes the regression problem of the direction vector as a classification problem. + Based on RegressionProcessing, however it reshapes the regression problem + of the direction vector as a classification problem. 
""" from types import SimpleNamespace import numpy as np -import torch -from dipy.core.geometry import sphere_distance -from dipy.core.sphere import Sphere -from dipy.data import get_sphere + from dfibert.config import Config -from dfibert.util import get_reference_orientation, rotation_from_vectors, get_grid, apply_rotation_matrix_to_grid, direction_to_classification, rotation_from_vectors_p +from dfibert.util import (get_reference_orientation, get_grid, + apply_rotation_matrix_to_grid, direction_to_classification, + rotation_from_multiple_vectors, get_sphere_from_param) class Processing(): """The basic Processing class. @@ -31,7 +32,7 @@ class Processing(): calculate_item(data_container, sl, next_direction) Calculates the (input, output) tuple for a single last streamline point - The methods can work together, but they do not have to. + The methods can work together, but they do not have to. The existence of both must be guaranteed to be able to use every dataset. """ # TODO - Live Calculation for Tracker @@ -49,7 +50,7 @@ def calculate_streamline(self, data_container, streamline): ------ NotImplementedError If the Processing subclass didn't overwrite the function. - + Returns ------- tuple @@ -93,7 +94,8 @@ class RegressionProcessing(Processing): grid: numpy.ndarray The grid, precalculated for this processing option id: str - An ID representing this Dataset. This is not unique to any instance, but it consists of parameters and used dataset. + An ID representing this Dataset. + This is not unique to any instance, but it consists of parameters and used dataset. Methods ------- @@ -103,7 +105,8 @@ class RegressionProcessing(Processing): Calculates the (input, output) tuple for a single streamline point """ - def __init__(self, rotate=None, grid_dimension=None, grid_spacing=None, postprocessing=None, normalize=None, normalize_mean=None, normalize_std=None): + def __init__(self, rotate=None, grid_dimension=None, grid_spacing=None, postprocessing=None, + normalize=None, normalize_mean=None, normalize_std=None): """ If the parameters are passed as none, the value from the config.ini is used. @@ -134,7 +137,6 @@ def __init__(self, rotate=None, grid_dimension=None, grid_spacing=None, postproc if isinstance(grid_dimension, tuple): grid_dimension = np.array(grid_dimension) - if grid_spacing is None: grid_spacing = config.getfloat("GridOptions", "spacing", fallback="1.0") @@ -180,7 +182,8 @@ def __init__(self, rotate=None, grid_dimension=None, grid_spacing=None, postproc self.options.postprocessing = postprocessing self.grid = get_grid(grid_dimension) * grid_spacing - self.id = "RegressionProcessing-r{}-grid{}x{}x{}-spacing{}-postprocessing-{}".format(rotate, *grid_dimension, grid_spacing, postprocessing.id) + self.id = ("RegressionProcessing-r{}-grid{}x{}x{}-spacing{}-postprocessing-{}" + .format(rotate, *grid_dimension, grid_spacing, postprocessing.id)) def calculate_item(self, data_container, previous_sl, next_dir): """Calculates the (input, output) tuple for the last streamline point. @@ -190,7 +193,8 @@ def calculate_item(self, data_container, previous_sl, next_dir): data_container : DataContainer The DataContainer the streamline is associated with previous_sl: np.array - The previous streamline point including the point the data should be calculated for in RAS* + The previous streamline point including the point + the data should be calculated for in RAS+ next_dir: Tensor The next direction, provide a null vector [0,0,0] if it is irrelevant. 
@@ -200,10 +204,11 @@ def calculate_item(self, data_container, previous_sl, next_dir): The (input, output) data for the requested item. """ # create artificial next_dirs consisting of last and next dir for rot_mat calculation - next_dirs = np.concatenate(((previous_sl[1:] - previous_sl[:-1])[-1:], next_dir[np.newaxis, ...])) + next_dirs = np.concatenate(((previous_sl[1:] - previous_sl[:-1])[-1:], + next_dir[np.newaxis, ...])) # TODO - normalize direction vectors next_dirs, rot_matrix = self._apply_rot_matrix(next_dirs) - + next_dir = next_dirs[-1] rot_matrix = None if rot_matrix is None else rot_matrix[np.newaxis, -1] dwi, _ = self._get_dwi(data_container, previous_sl[np.newaxis, -1], rot_matrix=rot_matrix) @@ -225,7 +230,7 @@ def calculate_streamline(self, data_container, streamline): The DataContainer the streamline is associated with streamline: Tensor The streamline the input and output data should be calculated for - + Returns ------- tuple @@ -234,7 +239,8 @@ def calculate_streamline(self, data_container, streamline): """ next_dir = self._get_next_direction(streamline) next_dir, rot_matrix = self._apply_rot_matrix(next_dir) - dwi, _ = self._get_dwi(data_container, streamline, rot_matrix=rot_matrix, postprocessing=self.options.postprocessing) + dwi, _ = self._get_dwi(data_container, streamline, rot_matrix=rot_matrix, + postprocessing=self.options.postprocessing) if self.options.postprocessing is not None: dwi = self.options.postprocessing(dwi, data_container.data.b0, data_container.data.bvecs, @@ -245,7 +251,7 @@ def calculate_streamline(self, data_container, streamline): def _get_dwi(self, data_container, streamline, rot_matrix=None, postprocessing=None): points = self._get_grid_points(streamline, rot_matrix=rot_matrix) - dwi = data_container.get_interpolated_dwi(points, postprocessing=postprocessing) + dwi = data_container.get_interpolated_dwi(points, postprocessing=postprocessing) return dwi , points def _get_next_direction(self, streamline): @@ -262,11 +268,10 @@ def _apply_rot_matrix(self, next_dir): # rot_mat (N, 3, 3) # next dir (N, 3) rot_matrix[0] = np.eye(3) - rotation_from_vectors_p(rot_matrix[1:, :, :], reference[None, :], next_dir[:-1]) + rotation_from_multiple_vectors(rot_matrix[1:, :, :], reference[None, :], next_dir[:-1]) rot_next_dir = (rot_matrix.transpose((0,2,1)) @ next_dir[:, :, None]).squeeze(2) return rot_next_dir, rot_matrix - def _get_grid_points(self, streamline, rot_matrix=None): grid = self.grid @@ -290,7 +295,8 @@ class ClassificationProcessing(RegressionProcessing): grid: numpy.ndarray The grid, precalculated for this processing option id: str - An ID representing this Dataset. This is not unique to any instance, but it consists of parameters and used dataset. + An ID representing this Dataset. This is not unique to any instance, + but it consists of parameters and used dataset. 
Methods ------- @@ -325,15 +331,12 @@ def __init__(self, rotate=None, grid_dimension=None, grid_spacing=None, postproc if sphere is None: sphere = Config.get_config().get("Processing", "classificationSphere", fallback="repulsion724") - if isinstance(sphere, Sphere): - rsphere = sphere - sphere = "custom" - else: - rsphere = get_sphere(sphere) + sphere, rsphere = get_sphere_from_param(sphere) self.sphere = rsphere self.options.sphere = sphere self.id = ("ClassificationProcessing-r{}-sphere-{}-grid{}x{}x{}-spacing{}-postprocessing-{}" - .format(self.options.rotate, self.options.sphere, *self.options.grid_dimension, self.options.grid_spacing, self.options.postprocessing.id)) + .format(self.options.rotate, self.options.sphere, *self.options.grid_dimension, + self.options.grid_spacing, self.options.postprocessing.id)) def calculate_streamline(self, data_container, streamline): """Calculates the classification (input, output) tuple for a whole streamline. @@ -344,7 +347,7 @@ def calculate_streamline(self, data_container, streamline): The DataContainer the streamline is associated with streamline: Tensor The streamline the input and output data should be calculated for - + Returns ------- tuple @@ -352,7 +355,8 @@ def calculate_streamline(self, data_container, streamline): """ dwi, next_dir = RegressionProcessing.calculate_streamline(self, data_container, streamline) - classification_output = direction_to_classification(self.sphere, next_dir, include_stop=True, last_is_stop=True) + classification_output = direction_to_classification(self.sphere, next_dir, + include_stop=True, last_is_stop=True) return dwi, classification_output def calculate_item(self, data_container, previous_sl, next_dir): @@ -363,7 +367,8 @@ def calculate_item(self, data_container, previous_sl, next_dir): data_container : DataContainer The DataContainer the streamline is associated with previous_sl: np.array - The previous streamline point including the point the data should be calculated for in RAS* + The previous streamline point including the point the data + should be calculated for in RAS+ next_dir: Tensor The next direction, provide a null vector [0,0,0] if it is irrelevant. @@ -372,6 +377,8 @@ def calculate_item(self, data_container, previous_sl, next_dir): tuple The (input, output) data for the requested item. 
""" - dwi, next_dir = RegressionProcessing.calculate_item(data_container, previous_sl, next_dir) - classification_output = direction_to_classification(self.sphere, next_dir[None, ...], include_stop=True, last_is_stop=True).squeeze(axis=0) - return dwi, classification_output \ No newline at end of file + dwi, next_dir = super().calculate_item(data_container, previous_sl, next_dir) + classification_output = (direction_to_classification(self.sphere, next_dir[None, ...], + include_stop=True, last_is_stop=True) + .squeeze(axis=0)) + return dwi, classification_output diff --git a/dfibert/envs/RLtractEnvironment.py b/dfibert/envs/RLtractEnvironment.py deleted file mode 100755 index d4cb5d7..0000000 --- a/dfibert/envs/RLtractEnvironment.py +++ /dev/null @@ -1,143 +0,0 @@ -import os, sys - -import gym -from gym import spaces -from gym.spaces import Discrete, Box -import numpy as np - -from dipy.data import get_sphere -import torch - - -from dfibert.data.postprocessing import res100, resample -from dfibert.data import HCPDataContainer, ISMRMDataContainer, PointOutsideOfDWIError -from dfibert.tracker import StreamlinesFromFileTracker -from dfibert.util import get_grid - -from ._state import TractographyState - - - -class RLtractEnvironment(gym.Env): - def __init__(self, device, stepWidth = 1, dataset = '100307', grid_dim = [3,3,3], maxL2dist_to_terminalState = 0.1, pReferenceStreamlines = "data/HCP307200_DTI_smallSet.vtk"): - #data/HCP307200_DTI_min40.vtk => 5k streamlines - print("Loading precomputed streamlines (%s) for ID %s" % (pReferenceStreamlines, dataset)) - self.device = device - self.dataset = HCPDataContainer(dataset) - self.dataset.normalize() #normalize HCP data - - self.stepWidth = stepWidth - self.dtype = torch.FloatTensor - sphere = get_sphere("repulsion100") - self.directions = sphere.vertices - noActions, _ = self.directions.shape - self.action_space = spaces.Discrete(noActions+1)#spaces.Discrete(noActions) - self.dwi_postprocessor = resample(sphere=sphere) - self.referenceStreamline_ijk = None - self.grid = get_grid(np.array(grid_dim)) - self.maxL2dist_to_terminalState = maxL2dist_to_terminalState - self.pReferenceStreamlines = pReferenceStreamlines - - self.state = self.reset() - - self.stepCounter = 0 - self.maxSteps = 200 - - def interpolateDWIatState(self, stateCoordinates): - #TODO: maybe stay in RAS all the time then no need to transfer to IJK - ras_points = self.dataset.to_ras(stateCoordinates) # Transform state to World RAS+ coordinate system - - ras_points = self.grid + ras_points - - try: - interpolated_dwi = self.dataset.get_interpolated_dwi(ras_points, postprocessing=self.dwi_postprocessor) - except: - return None - interpolated_dwi = np.rollaxis(interpolated_dwi,3) #CxWxHxD - #interpolated_dwi = self.dtype(interpolated_dwi).to(self.device) - return interpolated_dwi - - - def step(self, action): - if(action == (self.action_space.n - 1)) or (self.stepCounter > self.maxSteps): - #print("Entering terminal state") - done = True - reward = self.rewardForTerminalState(self.state) - if reward > 0.95: - reward += (1/self.stepCounter) - else: - reward -= self.stepCounter / (self.maxSteps / 10.) 
- return self.state, reward, done - - ## convert discrete action into tangent vector - action_vector = self.directions[action] - - ## apply step by step length and update state accordingly - positionNextState = self.state.getCoordinate() + self.stepWidth * action_vector - nextState = TractographyState(positionNextState, self.interpolateDWIatState) - if nextState.getValue() is None: - return self.state, -10, True - - ## compute reward for new state - rewardNextState = self.rewardForState(nextState) - - ### check if we already left brain map - # => RLenv.dataset.data.binarymask.shape - # set done = True if coordinate of nextState is outside of binarymask - done = False - self.stepCounter += 1 - try: - nextState.getValue() - except PointOutsideOfDWIError: - done = True - #print("Agent left brain mask :(") - return self.state, -10, done - - - self.state = TractographyState(positionNextState, self.interpolateDWIatState) - # return step information - return nextState, rewardNextState, done - - - def rewardForState(self, state): - # In general, the reward will be negative but very close to zero if the agent is - # staying close to our reference streamline. - # Right now, this function only returns negative rewards but simply adding some threshold - # to the LeakyReLU is gonna result in positive rewards, too - # - # We will be normalising the distance wrt. to LeakyRelu activation function. - qry_pt = torch.FloatTensor(state.getCoordinate()).view(-1,3) - distance = torch.min(torch.sum( (self.referenceStreamline_ijk - qry_pt)**2, dim =1 ) + torch.sum( (self.referenceStreamline_ijk[-1,:] - qry_pt)**4 )) - return torch.tanh(-distance+5.3) + self.rewardForTerminalState(state) / 2 - - - def rewardForTerminalState(self, state): - qry_pt = torch.FloatTensor(state.getCoordinate()).view(3) - distance = torch.sum( (self.referenceStreamline_ijk[-1,:] - qry_pt)**2 ) - #return torch.where(distance < self.maxL2dist_to_terminalState, 1, 0 ) - return torch.tanh(-distance+5.3) - - - # reset the game and returns the observed data from the last episode - def reset(self): - file_sl = StreamlinesFromFileTracker(self.pReferenceStreamlines) - file_sl.track() - - tracked_streamlines = file_sl.get_streamlines() - streamline_index = np.random.randint(len(tracked_streamlines)) - #print("Reset to streamline %d/%d" % (streamline_index+1, len(tracked_streamlines))) - referenceStreamline_ras = tracked_streamlines[streamline_index] - referenceStreamline_ijk = self.dataset.to_ijk(referenceStreamline_ras) - initialPosition_ijk = referenceStreamline_ijk[0] - - self.state = TractographyState(initialPosition_ijk, self.interpolateDWIatState) - self.done = False - self.referenceStreamline_ijk = self.dtype(referenceStreamline_ijk).to(self.device) - - self.stepCounter = 0 - - return self.state - - - def render(self, mode="human"): - pass diff --git a/dfibert/envs/_state.py b/dfibert/envs/_state.py index 1b05642..ca4bbd0 100755 --- a/dfibert/envs/_state.py +++ b/dfibert/envs/_state.py @@ -1,17 +1,19 @@ -import numpy as np - +"Contains the state for the tractography gym environment" class TractographyState: - def __init__(self, coordinate, interpolFuncHandle): + "The state for the tractography gym environment" + def __init__(self, coordinate, interpolation_func): self.coordinate = coordinate - self.interpolFuncHandle = interpolFuncHandle - self.interpolatedDWI = None + self.interpolation_func = interpolation_func + self.interpolated_dwi = None - def getCoordinate(self): + def get_coordinate(self): + "Returns the coordinate of the state" 
return self.coordinate - def getValue(self): - if self.interpolatedDWI is None: + def get_value(self): + "Returns the state value - the interpolated dwi" + if self.interpolated_dwi is None: # interpolate DWI value at self.coordinate - self.interpolatedDWI = self.interpolFuncHandle(self.coordinate) - return self.interpolatedDWI \ No newline at end of file + self.interpolated_dwi = self.interpolation_func(self.coordinate) + return self.interpolated_dwi diff --git a/dfibert/envs/tractography.py b/dfibert/envs/tractography.py new file mode 100755 index 0000000..509db36 --- /dev/null +++ b/dfibert/envs/tractography.py @@ -0,0 +1,151 @@ +"This module contains the Reinforcement Learning Environment for Tractography" +import gym +from gym import spaces +import numpy as np + +from dipy.data import get_sphere +import torch + + +from dfibert.data.postprocessing import resample +from dfibert.data import HCPDataContainer, PointOutsideOfDWIError +from dfibert.tracker import StreamlinesFromFileTracker +from dfibert.util import get_grid + +from ._state import TractographyState + + + +class EnvTractography(gym.Env): + "The tractography gym environment" + def __init__(self, device, step_width = 1, dataset = '100307', grid_dim = None, + max_l2dist_to_terminal_state = 0.1, + reference_streamlines_path = "data/HCP307200_DTI_smallSet.vtk"): + if grid_dim is None: + grid_dim = [3,3,3] + #data/HCP307200_DTI_min40.vtk => 5k streamlines + print("Loading precomputed streamlines (%s) for ID %s" % + (reference_streamlines_path, dataset)) + self.device = device + self.dataset = HCPDataContainer(dataset) + self.dataset.normalize() #normalize HCP data + + self.step_width = step_width + self.dtype = torch.FloatTensor + sphere = get_sphere("repulsion100") + self.directions = sphere.vertices + no_actions, _ = self.directions.shape + self.action_space = spaces.Discrete(no_actions+1)#spaces.Discrete(no_actions) + self.dwi_postprocessor = resample(sphere=sphere) + self.reference_streamline_ijk = None + self.grid = get_grid(np.array(grid_dim)) + self.max_l2dist_to_terminal_state = max_l2dist_to_terminal_state + self.reference_streamlines_path = reference_streamlines_path + + self.state = self.reset() + + self.step_counter = 0 + self.max_steps = 200 + + def interpolate_dwi_at_state(self, state_coordinates): + "Interpolates the DWI values for the given state" + #TODO: maybe stay in RAS all the time then no need to transfer to IJK + ras_points = self.dataset.to_ras(state_coordinates) + # Transform state to World RAS+ coordinate system + + ras_points = self.grid + ras_points + + try: + interpolated_dwi = self.dataset.get_interpolated_dwi(ras_points, + postprocessing=self.dwi_postprocessor) + except PointOutsideOfDWIError as _: + return None + interpolated_dwi = np.rollaxis(interpolated_dwi,3) #CxWxHxD + #interpolated_dwi = self.dtype(interpolated_dwi).to(self.device) + return interpolated_dwi + + def step(self, action): + if(action == (self.action_space.n - 1)) or (self.step_counter > self.max_steps): + #print("Entering terminal state") + done = True + reward = self.reward_for_terminal_state(self.state) + if reward > 0.95: + reward += (1/self.step_counter) + else: + reward -= self.step_counter / (self.max_steps / 10.) 
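+            # Shaping: stopping close to the reference endpoint (terminal reward > 0.95) earns a
+            # small bonus of 1/step_counter for finishing early; otherwise the reward is reduced
+            # in proportion to the number of steps already taken. Note that step_counter can
+            # still be 0 here if the stop action is chosen on the very first step.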
+ return self.state, reward, done + + ## convert discrete action into tangent vector + action_vector = self.directions[action] + + ## apply step by step length and update state accordingly + next_state_position = self.state.get_coordinate() + self.step_width * action_vector + next_state = TractographyState(next_state_position, self.interpolate_dwi_at_state) + if next_state.get_value() is None: + return self.state, -10, True + + ## compute reward for new state + reward_next_state = self.reward_for_state(next_state) + + ### check if we already left brain map + # => RLenv.dataset.data.binarymask.shape + # set done = True if coordinate of next_state is outside of binarymask + done = False + self.step_counter += 1 + try: + next_state.get_value() + except PointOutsideOfDWIError: + done = True + #print("Agent left brain mask :(") + return self.state, -10, done + + + self.state = TractographyState(next_state_position, self.interpolate_dwi_at_state) + # return step information + return next_state, reward_next_state, done + + + def reward_for_state(self, state): + "Returns the reward for a given state" + # In general, the reward will be negative but very close to zero if the agent is + # staying close to our reference streamline. + # Right now, this function only returns negative rewards but simply adding some threshold + # to the LeakyReLU is gonna result in positive rewards, too + # + # We will be normalising the distance wrt. to LeakyRelu activation function. + qry_pt = torch.FloatTensor(state.get_coordinate()).view(-1,3) + distance = (torch.min(torch.sum( (self.reference_streamline_ijk - qry_pt)**2, dim =1 ) + + torch.sum( (self.reference_streamline_ijk[-1,:] - qry_pt)**4 ))) + return torch.tanh(-distance+5.3) + self.reward_for_terminal_state(state) / 2 + + def reward_for_terminal_state(self, state): + "Returns the reward for a given terminal state" + qry_pt = torch.FloatTensor(state.get_coordinate()).view(3) + distance = torch.sum( (self.reference_streamline_ijk[-1,:] - qry_pt)**2 ) + #return torch.where(distance < self.max_l2dist_to_terminal_state, 1, 0 ) + return torch.tanh(-distance+5.3) + + + # reset the game and returns the observed data from the last episode + def reset(self): + file_sl = StreamlinesFromFileTracker(self.reference_streamlines_path) + file_sl.track() + + tracked_streamlines = file_sl.get_streamlines() + streamline_index = np.random.randint(len(tracked_streamlines)) + #print("Reset to streamline %d/%d" % (streamline_index+1, len(tracked_streamlines))) + reference_streamline_ras = tracked_streamlines[streamline_index] + reference_streamline_ijk = self.dataset.to_ijk(reference_streamline_ras) + initial_position_ijk = reference_streamline_ijk[0] + + self.state = TractographyState(initial_position_ijk, self.interpolate_dwi_at_state) + self.done = False + self.reference_streamline_ijk = self.dtype(reference_streamline_ijk).to(self.device) + + self.step_counter = 0 + + return self.state + + + def render(self, mode="human"): + pass diff --git a/dfibert/tracker/__init__.py b/dfibert/tracker/__init__.py index 269ea08..aad77a4 100644 --- a/dfibert/tracker/__init__.py +++ b/dfibert/tracker/__init__.py @@ -2,7 +2,9 @@ import os import random +from types import SimpleNamespace +import dipy.reconst.dti as dti from dipy.tracking.utils import random_seeds_from_mask, seeds_from_mask from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion @@ -13,12 +15,12 @@ from dipy.reconst.csdeconv import 
ConstrainedSphericalDeconvModel, auto_response_ssst from dipy.data import get_sphere, default_sphere from dipy.direction import peaks_from_model, DeterministicMaximumDirectionGetter -import dipy.reconst.dti as dti -from types import SimpleNamespace + from dfibert.config import Config from dfibert.cache import Cache -from .exceptions import StreamlinesAlreadyTrackedError, ISMRMStreamlinesNotCorrectError, StreamlinesNotTrackedError +from .exceptions import (StreamlinesAlreadyTrackedError, ISMRMStreamlinesNotCorrectError, + StreamlinesNotTrackedError) class Tracker(): """Universal Tracker class""" @@ -148,7 +150,8 @@ def track(self): fallback="10") fa_thr = Config.get_config().getfloat("CSDTracking", "autoResponseFaThreshold", fallback="0.7") - response, _ = auto_response_ssst(self.data.gtab, self.data.dwi, roi_radii=roi_r, fa_thr=fa_thr) + response, _ = auto_response_ssst(self.data.gtab, self.data.dwi, roi_radii=roi_r, + fa_thr=fa_thr) csd_model = ConstrainedSphericalDeconvModel(self.data.gtab, response) relative_peak_thr = Config.get_config().getfloat("CSDTracking", "relativePeakTreshold", fallback="0.5") @@ -208,7 +211,8 @@ def __init__(self, path): def track(self): Tracker.track(self) - self.streamlines = load_vtk_streamlines(self.path) # TODO catch exception if path does not exist + self.streamlines = load_vtk_streamlines(self.path) + # TODO catch exception if path does not exist class ISMRMReferenceStreamlinesTracker(Tracker): """Class representing the ISMRM 2015 Ground Truth fiber tracks.""" @@ -230,8 +234,8 @@ def track(self): for file in os.listdir(self.path): if file.endswith(".fib"): bundle_count = bundle_count + 1 - sl = load_vtk_streamlines(os.path.join(self.path, file)) - self.streamlines.extend(sl) + streamlines = load_vtk_streamlines(os.path.join(self.path, file)) + self.streamlines.extend(streamlines) if len(self.streamlines) != 200433 or bundle_count != 25: raise ISMRMStreamlinesNotCorrectError(self, self.path) if self.options.streamline_count is not None: diff --git a/dfibert/tracker/exceptions.py b/dfibert/tracker/exceptions.py index 2779c89..9afc3cf 100644 --- a/dfibert/tracker/exceptions.py +++ b/dfibert/tracker/exceptions.py @@ -1,3 +1,4 @@ +"Exceptions for the tracker submodule" class StreamlinesAlreadyTrackedError(Exception): """Error thrown if streamlines are already tracked.""" @@ -25,4 +26,4 @@ def __init__(self, tracker): self.data_container = tracker.data_container super().__init__( ("The streamlines weren't tracked yet from Dataset {id}. " "Call Tracker.track() to track the streamlines.") - .format(id=self.data_container.id)) \ No newline at end of file + .format(id=self.data_container.id)) diff --git a/dfibert/util.py b/dfibert/util.py index 94762f8..0e5651f 100644 --- a/dfibert/util.py +++ b/dfibert/util.py @@ -4,14 +4,35 @@ import numpy as np from dipy.core.sphere import Sphere from dipy.core.geometry import sphere_distance +from dipy.data import get_sphere + from .config import Config -def rotation_from_vectors_p(rot, vectors_orig, vectors_fin): +def rotation_from_multiple_vectors(rot, vectors_orig, vectors_fin): + """Calculates the rotation matrices required to rotate from one list of vectors to another. + + For the rotation of one vector to another, there are an infinit series of rotation matrices + possible. Due to axially symmetry, the rotation axis can be any vector lying in the symmetry + plane between the two vectors. 
Hence the axis-angle convention will be used to construct the + matrix with the rotation axis defined as the cross product of the two vectors. The rotation + angle is the arccosine of the dot product of the two unit vectors. + Given a unit vector parallel to the rotation axis, w = [x, y, z] and the rotation angle a, + the rotation matrix R is:: + | 1 + (1-cos(a))*(x*x-1) -z*sin(a)+(1-cos(a))*x*y y*sin(a)+(1-cos(a))*x*z | + R = | z*sin(a)+(1-cos(a))*x*y 1 + (1-cos(a))*(y*y-1) -x*sin(a)+(1-cos(a))*y*z | + | -y*sin(a)+(1-cos(a))*x*z x*sin(a)+(1-cos(a))*y*z 1 + (1-cos(a))*(z*z-1) | + @param rot: The Nx3x3 rotation matrix to update. + @type rot: Nx3x3 numpy array + @param vector_orig: The unrotated vector defined in the reference frame. + @type vector_orig: numpy array, dim Nx3 + @param vector_fin: The rotated vector defined in the reference frame. + @type vector_fin: numpy array, dim Nx3 + """ vectors_orig = vectors_orig / np.linalg.norm(vectors_orig, axis=1)[:, None] vectors_fin = vectors_fin / np.linalg.norm(vectors_fin, axis=1)[:, None] axes = np.cross(vectors_orig, vectors_fin) axes_lens = np.linalg.norm(axes, axis=1) - + axes_lens[axes_lens == 0] = 1 axes = axes/axes_lens[:,None] @@ -22,17 +43,17 @@ def rotation_from_vectors_p(rot, vectors_orig, vectors_fin): angles = np.arccos(np.sum(vectors_orig * vectors_fin, axis=1)) - sa = np.sin(angles) - ca = np.cos(angles) # cos - rot[:,0, 0] = 1.0 + (1.0 - ca)*(x**2 - 1.0) - rot[:,0, 1] = -z*sa + (1.0 - ca)*x*y - rot[:,0, 2] = y*sa + (1.0 - ca)*x*z - rot[:,1, 0] = z*sa+(1.0 - ca)*x*y - rot[:,1, 1] = 1.0 + (1.0 - ca)*(y**2 - 1.0) - rot[:,1, 2] = -x*sa+(1.0 - ca)*y*z - rot[:,2, 0] = -y*sa+(1.0 - ca)*x*z - rot[:,2, 1] = x*sa+(1.0 - ca)*y*z - rot[:,2, 2] = 1.0 + (1.0 - ca)*(z**2 - 1.0) + sin_angles = np.sin(angles) + cos_angles = np.cos(angles) # cos + rot[:,0, 0] = 1.0 + (1.0 - cos_angles)*(x**2 - 1.0) + rot[:,0, 1] = -z*sin_angles + (1.0 - cos_angles)*x*y + rot[:,0, 2] = y*sin_angles + (1.0 - cos_angles)*x*z + rot[:,1, 0] = z*sin_angles+(1.0 - cos_angles)*x*y + rot[:,1, 1] = 1.0 + (1.0 - cos_angles)*(y**2 - 1.0) + rot[:,1, 2] = -x*sin_angles+(1.0 - cos_angles)*y*z + rot[:,2, 0] = -y*sin_angles+(1.0 - cos_angles)*x*z + rot[:,2, 1] = x*sin_angles+(1.0 - cos_angles)*y*z + rot[:,2, 2] = 1.0 + (1.0 - cos_angles)*(z**2 - 1.0) def rotation_from_vectors(rot, vector_orig, vector_fin): @@ -74,23 +95,23 @@ def rotation_from_vectors(rot, vector_orig, vector_fin): angle = np.arccos(np.dot(vector_orig, vector_fin)) # Trig functions (only need to do this maths once!). - ca = np.cos(angle) - sa = np.sin(angle) + cos = np.cos(angle) + sin = np.sin(angle) # Calculate the rotation matrix elements. 
- rot[0, 0] = 1.0 + (1.0 - ca)*(x**2 - 1.0) - rot[0, 1] = -z*sa + (1.0 - ca)*x*y - rot[0, 2] = y*sa + (1.0 - ca)*x*z - rot[1, 0] = z*sa+(1.0 - ca)*x*y - rot[1, 1] = 1.0 + (1.0 - ca)*(y**2 - 1.0) - rot[1, 2] = -x*sa+(1.0 - ca)*y*z - rot[2, 0] = -y*sa+(1.0 - ca)*x*z - rot[2, 1] = x*sa+(1.0 - ca)*y*z - rot[2, 2] = 1.0 + (1.0 - ca)*(z**2 - 1.0) + rot[0, 0] = 1.0 + (1.0 - cos)*(x**2 - 1.0) + rot[0, 1] = -z*sin + (1.0 - cos)*x*y + rot[0, 2] = y*sin + (1.0 - cos)*x*z + rot[1, 0] = z*sin+(1.0 - cos)*x*y + rot[1, 1] = 1.0 + (1.0 - cos)*(y**2 - 1.0) + rot[1, 2] = -x*sin+(1.0 - cos)*y*z + rot[2, 0] = -y*sin+(1.0 - cos)*x*z + rot[2, 1] = x*sin+(1.0 - cos)*y*z + rot[2, 2] = 1.0 + (1.0 - cos)*(z**2 - 1.0) def get_reference_orientation(): """Get current reference rotation - + Returns ------- numpy.ndarray @@ -109,7 +130,7 @@ def get_reference_orientation(): ref = ref * -1 return ref -def get_2D_sphere(no_phis=None, no_thetas=None): +def get_2d_sphere(no_phis=None, no_thetas=None): """Retrieve evenly distributed 2D sphere out of phi and theta count. @@ -129,10 +150,10 @@ def get_2D_sphere(no_phis=None, no_thetas=None): no_thetas = Config.get_config().getint("2DSphereOptions", "noThetas", fallback="16") if no_phis is None: no_phis = Config.get_config().getint("2DSphereOptions", "noPhis", fallback="16") - xi = np.arange(0, np.pi, (np.pi) / no_thetas) # theta - yi = np.arange(-np.pi, np.pi, 2 * (np.pi) / no_phis) # phi + x_values = np.arange(0, np.pi, (np.pi) / no_thetas) # theta + y_values = np.arange(-np.pi, np.pi, 2 * (np.pi) / no_phis) # phi - basis = np.array(np.meshgrid(yi, xi)) + basis = np.array(np.meshgrid(y_values, x_values)) sphere = Sphere(theta=basis[0, :], phi=basis[1, :]) @@ -151,8 +172,8 @@ def get_grid(grid_dimension): numpy.ndarray The requested grid """ - (dx, dy, dz) = (grid_dimension - 1)/2 - return np.moveaxis(np.mgrid[-dx:dx+1, -dy:dy+1, -dz:dz+1], 0, 3) + (delta_x, delta_y, delta_z) = (grid_dimension - 1)/2 + return np.moveaxis(np.mgrid[-delta_x:delta_x+1, -delta_y:delta_y+1, -delta_z:delta_z+1], 0, 3) def random_split(dataset, training_part=0.9): """Retrieves a dataset from given path and splits them randomly in train and test data. @@ -177,7 +198,7 @@ def random_split(dataset, training_part=0.9): def get_mask_from_lengths(lengths): """Returns a mask for given array of lengths - + Parameters ---------- lengths: Tensor @@ -194,7 +215,7 @@ def apply_rotation_matrix_to_grid(grid, rot_matrix): Parameters ---------- grid : numpy.ndarray - The grid + The grid rot_matrix : numpy.ndarray The rotation matrix with the dimensions (N, 3, 3) @@ -203,16 +224,22 @@ def apply_rotation_matrix_to_grid(grid, rot_matrix): numpy.ndarray The grid, rotated along the rotation_matrix; Shape: (N, ...grid_dimensions) """ - return (rot_matrix.repeat(grid.size/3, axis=0) @ grid[None, ].repeat(len(rot_matrix), axis=0).reshape(-1, 3, 1)).reshape((-1, *grid.shape)) + return ((rot_matrix.repeat(grid.size/3, axis=0) @ + (grid[None, ].repeat(len(rot_matrix), axis=0).reshape(-1, 3, 1))) + .reshape((-1, *grid.shape))) -def direction_to_classification(sphere, next_dir, include_stop=False, last_is_stop=False, stop_values=None): +def direction_to_classification(sphere, next_dir, include_stop=False, + last_is_stop=False, stop_values=None): + """ + Converts the directions into appropriate classification values for the given sphere. 
+ """ # code adapted from Benou "DeepTract",exi # https://github.com/itaybenou/DeepTract/blob/master/utils/train_utils.py sl_len = len(next_dir) loop_len = sl_len - 1 if include_stop and last_is_stop else sl_len - l = len(sphere.theta) + 1 if include_stop else len(sphere.theta) - classification_output = np.zeros((sl_len, l)) + classification_len = len(sphere.theta) + 1 if include_stop else len(sphere.theta) + classification_output = np.zeros((sl_len, classification_len)) for i in range(loop_len): if not (next_dir[i,0] == 0.0 and next_dir[i, 1] == 0.0 and next_dir[i, 2] == 0.0): labels_odf = np.exp(-1 * sphere_distance(next_dir[i, :], np.asarray( @@ -226,4 +253,15 @@ def direction_to_classification(sphere, next_dir, include_stop=False, last_is_st classification_output[-1, -1] = 1 # stop condition or if include_stop and stop_values is not None: classification_output[:,-1] = stop_values # stop values - return classification_output \ No newline at end of file + return classification_output + + +def get_sphere_from_param(sphere, directions=None): + "Given a sphere as either name or Sphere, returns a tuple of name, actual_sphere" + sphere_name = "custom" + if directions is not None: + return sphere_name, Sphere(xyz=directions) + if isinstance(sphere, Sphere): + return sphere_name, sphere + else: + return sphere, get_sphere(sphere) diff --git a/examples/env-toyground-Nico.ipynb b/examples/env-toyground-Nico.ipynb index a311b5b..28cc3c1 100755 --- a/examples/env-toyground-Nico.ipynb +++ b/examples/env-toyground-Nico.ipynb @@ -12,7 +12,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRLtractEnvironment\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mRLTe\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdfibert\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHCPDataContainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mISMRMDataContainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtractography\u001b[0m 
\u001b[0;32mas\u001b[0m \u001b[0mRLTe\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mdfibert\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mHCPDataContainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mISMRMDataContainer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'envs'" ] } @@ -29,7 +29,7 @@ "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", - "import dfibert.envs.RLtractEnvironment as RLTe\n", + "import dfibert.envs.tractography as RLTe\n", "from dfibert.data import HCPDataContainer, ISMRMDataContainer" ] }, @@ -55,7 +55,7 @@ } ], "source": [ - "RLenv = RLTe.RLtractEnvironment(device = 'cpu', grid_dim=[3,3,3])" + "RLenv = RLTe.EnvTractography(device = 'cpu', grid_dim=[3,3,3])" ] }, { @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "nextState.getValue()" + "nextState.get_value()" ] }, { diff --git a/examples/train-toyground-Pia.ipynb b/examples/train-toyground-Pia.ipynb index f51106d..8c88ee1 100644 --- a/examples/train-toyground-Pia.ipynb +++ b/examples/train-toyground-Pia.ipynb @@ -15,7 +15,7 @@ "sys.path.insert(0,'..')\n", "\n", "from dfibert.tracker.nn.rl import Agent, Action_Scheduler\n", - "import dfibert.envs.RLtractEnvironment as RLTe" + "import dfibert.envs.tractography as RLTe" ] }, { @@ -50,7 +50,7 @@ "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(\"Init environment..\")\n", - "env = RLTe.RLtractEnvironment(device = 'cpu')\n", + "env = RLTe.EnvTractography(device = 'cpu')\n", "print(\"..done!\")\n", "n_actions = env.action_space.n\n", "#print(n_actions)" @@ -65,7 +65,7 @@ "print(\"Init agent\")\n", "#memory = ReplayMemory(size=replay_memory_size)\n", "state = env.reset()\n", - "agent = Agent(n_actions=n_actions, inp_size=state.getValue().shape, device=device, hidden=256, agent_history_length=agent_history_length, memory_size=replay_memory_size, learning_rate=learning_rate)\n", + "agent = Agent(n_actions=n_actions, inp_size=state.get_value().shape, device=device, hidden=256, agent_history_length=agent_history_length, memory_size=replay_memory_size, learning_rate=learning_rate)\n", "\n", "print(\"Init epsilon-greedy action scheduler\")\n", "action_scheduler = Action_Scheduler(num_actions=n_actions, max_steps=max_steps, eps_annealing_steps=100000, replay_memory_start_size=replay_memory_size, model=agent.main_dqn)\n", @@ -86,7 +86,7 @@ " #fill replay memory while interacting with env\n", " for episode_counter in range(max_episode_length):\n", " # get action with epsilon-greedy strategy \n", - " action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.getValue()).to(device).unsqueeze(0))\n", + " action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.get_value()).to(device).unsqueeze(0))\n", " \n", " next_state, reward, terminal = env.step(action)\n", "\n", @@ -107,9 +107,9 @@ "\n", "\n", " agent.replay_memory.add_experience(action=action,\n", - " state=state.getValue(),\n", + " state=state.get_value(),\n", " reward=reward,\n", - " new_state=next_state.getValue(),\n", + " new_state=next_state.get_value(),\n", " terminal=terminal)\n", "\n", "\n", @@ -146,7 +146,7 @@ " state = env.reset()\n", " eval_episode_reward = 0\n", " while eval_steps < max_episode_length:\n", - " action = action_scheduler.get_action(step_counter, 
torch.FloatTensor(state.getValue()).to(device).unsqueeze(0), evaluation=True)\n", + " action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.get_value()).to(device).unsqueeze(0), evaluation=True)\n", "\n", " next_state, reward, terminal = env.step(action)\n", "\n", diff --git a/examples/train.py b/examples/train.py index 523ef15..7ce9da2 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,27 +1,29 @@ +"RL Training" +import os +import sys +import argparse + import torch import gym import numpy as np -import argparse -import os, sys sys.path.insert(0,'..') from dfibert.tracker.nn.rl import Agent, Action_Scheduler - -import dfibert.envs.RLtractEnvironment as RLTe +import dfibert.envs.tractography as RLTe def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps=100000, agent_history_length=1, evaluate_every=20000, eval_runs=5, network_update_every=10000, max_episode_length=200, learning_rate=0.0000625): - + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print("Init environment..") - env = RLTe.RLtractEnvironment(device = 'cpu') + env = RLTe.EnvTractography(device = 'cpu') print("..done!") n_actions = env.action_space.n print("Init agent") state = env.reset() - agent = Agent(n_actions=n_actions, inp_size=state.getValue().shape, device=device, hidden=512, agent_history_length=agent_history_length, memory_size=replay_memory_size, learning_rate=learning_rate) + agent = Agent(n_actions=n_actions, inp_size=state.get_value().shape, device=device, hidden=512, agent_history_length=agent_history_length, memory_size=replay_memory_size, learning_rate=learning_rate) print("Init epsilon-greedy action scheduler") action_scheduler = Action_Scheduler(num_actions=n_actions, max_steps=max_steps, eps_annealing_steps=eps_annealing_steps, replay_memory_start_size=replay_memory_size, model=agent.main_dqn) @@ -43,7 +45,7 @@ def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps #for episode_counter in range(max_episode_length): while not terminal: # get action with epsilon-greedy strategy - action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.getValue()).to(device).unsqueeze(0)) + action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.get_value()).to(device).unsqueeze(0)) next_state, reward, terminal = env.step(action) @@ -55,9 +57,9 @@ def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps agent.replay_memory.add_experience(action=action, - state=state.getValue(), + state=state.get_value(), reward=reward, - new_state=next_state.getValue(), + new_state=next_state.get_value(), terminal=terminal) @@ -88,7 +90,7 @@ def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps print("[{}] {}, {}, current eps {}".format(len(eps_rewards), step_counter, np.mean(eps_rewards[-1000:]), action_scheduler.eps_current) ) torch.save(agent.main_dqn.state_dict(), path+'/checkpoints/fibre_agent_{}_reward_{:.2f}.pth'.format(step_counter, np.mean(eps_rewards[-1000:]))) - ########## evaluation starting here + ########## evaluation starting here eval_rewards = [] agent.main_dqn.eval() for _ in range(eval_runs): @@ -97,7 +99,7 @@ def train(path, max_steps=3000000, replay_memory_size=20000, eps_annealing_steps eval_episode_reward = 0 episode_final = 0 while eval_steps < max_episode_length: - action = action_scheduler.get_action(step_counter, torch.FloatTensor(state.getValue()).to(device).unsqueeze(0), evaluation=True) + action = 
action_scheduler.get_action(step_counter, torch.FloatTensor(state.get_value()).to(device).unsqueeze(0), evaluation=True) next_state, reward, terminal = env.step(action) diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..25ac030 --- /dev/null +++ b/pylintrc @@ -0,0 +1,9 @@ +[FORMAT] +# Exceptions to snake_case naming policy +good-names=i,j,k,x,y,z,to,id,t1,fa,b0 + +[TYPECHECK] + +# List of members which are set dynamically and missed by Pylint inference +# system, and so shouldn't trigger E1101 when accessed. +generated-members=numpy.*, torch.*, dipy.* \ No newline at end of file
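For reference, a minimal usage sketch of the renamed environment module, mirroring how the updated notebooks and examples/train.py drive it. It assumes the repository root is on sys.path and that HCP subject 100307 plus the reference streamline file data/HCP307200_DTI_smallSet.vtk are available locally.

import sys
sys.path.insert(0, '..')                      # as in the example notebooks, when run from examples/
import dfibert.envs.tractography as RLTe

env = RLTe.EnvTractography(device='cpu', grid_dim=[3, 3, 3])
state = env.reset()                           # seeds at the start of a randomly chosen reference streamline
action = env.action_space.sample()            # 100 sphere directions plus one extra "stop" action
next_state, reward, done = env.step(action)   # note: returns three values, not the 4-tuple of newer gym APIs
dwi_patch = next_state.get_value()            # channels-first DWI patch interpolated on the local 3x3x3 grid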
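The renamed rotation_from_multiple_vectors helper fills a preallocated Nx3x3 array in place rather than returning a value. A small sketch of that calling convention (the example vectors are arbitrary unit vectors):

import numpy as np
from dfibert.util import rotation_from_multiple_vectors

vectors_orig = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
vectors_fin = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
rot = np.zeros((len(vectors_orig), 3, 3))
rotation_from_multiple_vectors(rot, vectors_orig, vectors_fin)   # fills rot in place, no return value
assert np.allclose(rot[0] @ vectors_orig[0], vectors_fin[0])     # each rot[i] maps orig[i] onto fin[i]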
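Similarly, the new get_sphere_from_param helper normalises the different ways a sphere can be specified. A brief sketch of the three accepted inputs (the chosen sphere names are just examples):

import numpy as np
from dipy.data import get_sphere
from dfibert.util import get_sphere_from_param

name, sphere = get_sphere_from_param("repulsion100")               # by name -> ("repulsion100", dipy Sphere)
name, sphere = get_sphere_from_param(get_sphere("repulsion724"))   # Sphere instance -> ("custom", that Sphere)
name, sphere = get_sphere_from_param(None, directions=np.array([[1.0, 0.0, 0.0],
                                                                [0.0, 1.0, 0.0]]))  # raw unit directions -> ("custom", Sphere(xyz=...))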