1
+ """
2
+ Nifti wrapper that includes addtional meta data. The meta data is embedded into
3
+ the Nifti as an extension.
4
+
5
+ @author: moloney
6
+ """
7
+ import json
8
+ import numpy as np
9
+ import nibabel as nb
10
+ from nibabel .nifti1 import Nifti1Extension
11
+
12
+ dcm_meta_ecode = 19
13
+
14
+ class DcmMetaExtension (Nifti1Extension ):
15
+ '''Subclass on Nifti1Extension. Handles conversion to and from json, checks
16
+ the validity of the extension, and provides access to the "meta meta" data.
17
+ '''
18
+
19
+ _req_base_keys = set (('dcmmeta_affine' ,
20
+ 'dcmmeta_slice_dim' ,
21
+ 'dcmmeta_shape' ,
22
+ 'dcmmeta_version' ,
23
+ 'global' ,
24
+ )
25
+ )
26
+
27
+ def _unmangle (self , value ):
28
+ return json .loads (value )
29
+
30
+ def _mangle (self , value ):
31
+ return json .dumps (value , indent = 4 )
32
+
33
+ def is_valid (self ):
34
+ '''Check if the extension is valid.'''
35
+ #Check for the required base keys in the json
36
+ if not self ._req_base_keys <= set (self ._content ):
37
+ return False
38
+
39
+ shape = self ._content ['dcmmeta_shape' ]
40
+
41
+ #Check the 'global' dictionary
42
+ if not set (('const' , 'slices' )) == set (self ._content ['global' ]):
43
+ return False
44
+
45
+ total_slices = [self .get_n_slices ()]
46
+ for dim_size in shape [3 :]:
47
+ total_slices .append (total_slices [- 1 ]* dim_size )
48
+ for key , vals in self ._content ['global' ]['slices' ].iteritems ():
49
+ if len (vals ) != total_slices [- 1 ]:
50
+ return False
51
+
52
+ #Check 'time' and 'vector' dictionaries if they exist
53
+ if len (shape ) > 3 :
54
+ if not 'time' in self ._content :
55
+ return False
56
+ if not set (('samples' , 'slices' )) == set (self ._content ['time' ]):
57
+ return False
58
+ for key , vals in self ._content ['time' ]['samples' ].iteritems ():
59
+ if len (vals ) != shape [3 ]:
60
+ return False
61
+ for key , vals in self ._content ['time' ]['slices' ].iteritems ():
62
+ if len (vals ) != total_slices [0 ]:
63
+ return False
64
+ if len (shape ) > 4 :
65
+ if not 'vector' in self ._content :
66
+ return False
67
+ if not set (('samples' , 'slices' )) == set (self ._content ['vector' ]):
68
+ return False
69
+ for key , vals in self ._content ['time' ]['samples' ].iteritems ():
70
+ if len (vals ) != shape [4 ]:
71
+ return False
72
+ for key , vals in self ._content ['vector' ]['slices' ].iteritems ():
73
+ if len (vals ) != total_slices [1 ]:
74
+ return False
75
+
76
+ return True
77
+
78
+ def get_affine (self ):
79
+ return np .array (self ._content ['dcmmeta_affine' ])
80
+
81
+ def get_slice_dim (self ):
82
+ return self ._content ['dcmmeta_slice_dim' ]
83
+
84
+ def get_shape (self ):
85
+ return tuple (self ._content ['dcmmeta_shape' ])
86
+
87
+ def get_n_slices (self ):
88
+ return self .get_shape ()[self .get_slice_dim ()]
89
+
90
+ def get_version (self ):
91
+ return self ._content ['dcmmeta_version' ]
92
+
93
+ def to_json_file (self , path ):
94
+ '''Write out a JSON formatted text file with the extensions contents.'''
95
+ if not self .is_valid ():
96
+ raise ValueError ('The content dictionary is not valid.' )
97
+ out_file = open (path , 'w' )
98
+ out_file .write (self ._mangle (self ._content ))
99
+ out_file .close ()
100
+
101
+ @classmethod
102
+ def from_json_file (klass , path ):
103
+ '''Read in a JSON formatted text file with the extensions contents.'''
104
+ in_file = open (path )
105
+ content = in_file .read ()
106
+ in_file .close ()
107
+ result = klass (dcm_meta_ecode , content )
108
+ if not result .is_valid ():
109
+ raise ValueError ('The JSON is not valid.' )
110
+ return result
111
+
112
+ @classmethod
113
+ def from_runtime_repr (klass , runtime_repr ):
114
+ result = klass (dcm_meta_ecode , '{}' )
115
+ result ._content = runtime_repr
116
+ if not result .is_valid ():
117
+ raise ValueError ('The runtime representation is not valid.' )
118
+ return result
119
+
120
+ #Add our extension to nibabel
121
+ nb .nifti1 .extension_codes .add_codes (((dcm_meta_ecode ,
122
+ "dcmmeta" ,
123
+ DcmMetaExtension ),)
124
+ )
125
+
126
+ class NiftiWrapper (object ):
127
+ '''Wraps a nibabel.Nifti1Image object containing a DcmMetaExtension header
128
+ extension. Provides transparent access to the meta data through 'get_meta'.
129
+ Allows the Nifti to be split into sub volumes or joined with others, while
130
+ also updating the meta data appropriately.'''
131
+
132
+ def __init__ (self , nii_img ):
133
+ self .nii_img = nii_img
134
+ self ._meta_ext = None
135
+ for extension in nii_img .get_header ().extensions :
136
+ if extension .get_code () == dcm_meta_ecode :
137
+ if self ._meta_ext :
138
+ raise ValueError ("More than one DcmMetaExtension found" )
139
+ self ._meta_ext = extension
140
+ if not self ._meta_ext :
141
+ raise ValueError ("No DcmMetaExtension found." )
142
+ if not self ._meta_ext .is_valid ():
143
+ raise ValueError ("The meta extension is not valid" )
144
+
145
+ def samples_valid (self ):
146
+ '''Check if the meta data corresponding to individual time or vector
147
+ samples appears to be valid for the wrapped nifti image.'''
148
+ #Check if the slice/time/vector dimensions match
149
+ img_shape = self .nii_img .get_shape ()
150
+ meta_shape = self ._meta_ext .get_shape ()
151
+ return meta_shape [2 :] == img_shape [2 :]
152
+
153
+ def slices_valid (self ):
154
+ '''Check if the meta data corresponding to individual slices appears to
155
+ be valid for the wrapped nifti image.'''
156
+
157
+ if self ._meta_ext .get_n_slices () != self .nii_img .get_n_slices ():
158
+ return False
159
+
160
+ #Check that the affines match
161
+ return np .allclose (self .nii_img .get_affine (),
162
+ self ._meta_ext .get_affine ())
163
+
164
+ def get_meta (self , key , index = None , default = None ):
165
+ '''Return the meta data value for the provided 'key', or 'default' if
166
+ there is no such (valid) key.
167
+
168
+ If 'index' is not provided, only meta data values that are constant
169
+ across the entire data set will be considered. If 'index' is provided it
170
+ must be a valid index for the nifti voxel data, and all of the meta data
171
+ that is applicable to that index will be considered. The per-slice meta
172
+ data will only be considered if the object's 'is_aligned' method returns
173
+ True.'''
174
+
175
+ #Pull out the meta dictionary
176
+ meta_dict = self ._meta_ext .get_content ()
177
+
178
+ #First check the constant values
179
+ if key in meta_dict ['global' ]['const' ]:
180
+ return meta_dict ['global' ]['const' ][key ]
181
+
182
+ #If an index is provided check the varying values
183
+ if not index is None :
184
+ #Test if the index is valid
185
+ shape = self .nii_img .get_shape ()
186
+ if len (index ) != len (shape ):
187
+ raise IndexError ('Incorrect number of indices.' )
188
+ for dim , ind_val in enumerate (index ):
189
+ if ind_val < 0 or ind_val >= shape [dim ]:
190
+ raise IndexError ('Index is out of bounds.' )
191
+
192
+ #First try per time/vector sample values
193
+ if self .samples_valid ():
194
+ if (len (shape ) > 3 and shape [3 ] > 1 and
195
+ key in meta_dict ['time' ]['samples' ]):
196
+ return meta_dict ['time' ]['samples' ][key ][index [2 ]]
197
+ if (len (shape ) > 4 and shape [4 ] > 1 and
198
+ key in meta_dict ['vector' ]['samples' ]):
199
+ return meta_dict ['vector' ]['samples' ][key ][index [3 ]]
200
+
201
+ #Finally, if aligned, try per-slice values
202
+ if self .slices_valid ():
203
+ slice_dim = self ._meta_ext .get_slice_dim ()
204
+ if key in meta_dict ['global' ]['slices' ]:
205
+ val_idx = index [slice_dim ]
206
+ slices_per_sample = shape [slice_dim ]
207
+ for count , idx_val in enumerate (index [3 :]):
208
+ val_idx += idx_val * slices_per_sample
209
+ slices_per_sample *= shape [count + 3 ]
210
+ return meta_dict ['global' ]['slices' ][key ][val_idx ]
211
+
212
+ if self .samples_valid ():
213
+ if (len (shape ) > 3 and shape [3 ] > 1 and
214
+ key in meta_dict ['time' ]['slices' ]):
215
+ val_idx = index [slice_dim ]
216
+ return meta_dict ['time' ]['slices' ][key ][val_idx ]
217
+ elif (len (shape ) > 4 and shape [4 ] > 1 and
218
+ key in meta_dict ['vector' ]['slices' ]):
219
+ val_idx = index [slice_dim ]
220
+ val_idx += index [3 ]* shape [slice_dim ]
221
+ return meta_dict ['vector' ]['slices' ][key ][val_idx ]
222
+
223
+ return default
224
+
225
+ def split (self , dim_idx = None ):
226
+ '''Split the meta data along the index 'dim_idx', returning a list of
227
+ NiftiWrapper objects. If 'dim_idx' is None it will prefer the vector,
228
+ then time, then slice dimensions.
229
+ '''
230
+ # shape = self.nii_img.get_shape()
231
+ # slice_dim = self.nii_img.get_dim_info()[2]
232
+ #
233
+ # #If dim_idx is None, choose the vector/time/slice dim in that order
234
+ # if dim_idx is None:
235
+ # dim_idx = len(shape) - 1
236
+ # if dim_idx == 2:
237
+ # dim_idx = slice_dim
238
+ #
239
+ # data = self.nii_img.get_data()
240
+ # affine = self.nii_img.get_affine()
241
+ # header = self.nii_img.get_header()
242
+ # if dim_idx == slice_dim:
243
+ # header. #Need to unset slice specific bits of the header here.
244
+ # results = []
245
+ # slices = [slice(None)] * len(shape)
246
+ # for idx in shape[dim_idx]:
247
+ # slices[dim_idx] = idx
248
+ # split_data = data[slices].copy()
249
+ # results.append(nb.Nifti1Image(split_data,
250
+ # affine.copy(),
251
+ # header.copy()
252
+ # )
253
+ # )
254
+ #
255
+ # return results
256
+
257
+
258
+ def to_filename (self , out_path ):
259
+ if not self ._meta_ext .is_valid :
260
+ raise ValueError ("Meta extension is not valid." )
261
+ self .nii_img .to_filename (out_path )
262
+
263
+ @classmethod
264
+ def from_filename (klass , path ):
265
+ return klass (nb .load (path ))
266
+
267
+ @classmethod
268
+ def from_sequence (klass , others , dim_idx = None ):
269
+ '''Create a NiftiWrapper from a sequence of other NiftiWrappers objects.
270
+ The Nifti volumes are stacked along dim_idx in the given order.
271
+ '''
272
+
0 commit comments