1212from functools import partial
1313from collections import namedtuple
1414import numpy as np
15+ import nibabel as nb
1516
1617from nitransforms import io
1718from nitransforms .io .base import _ensure_image
19+ from nitransforms .io .x5 import from_filename as load_x5
1820from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
1921from nitransforms .base import (
2022 TransformBase ,
3436class DenseFieldTransform (TransformBase ):
3537 """Represents dense field (voxel-wise) transforms."""
3638
37- __slots__ = ("_field" , "_deltas" )
39+ __slots__ = ("_field" , "_deltas" , "_is_deltas" )
3840
3941 def __init__ (self , field = None , is_deltas = True , reference = None ):
4042 """
@@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6870
6971 super ().__init__ ()
7072
71- if field is not None :
72- field = _ensure_image (field )
73- self ._field = np .squeeze (
74- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
75- )
76- else :
77- self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
78- is_deltas = True
73+ self ._is_deltas = is_deltas
7974
8075 try :
8176 self .reference = ImageGrid (reference if reference is not None else field )
@@ -86,24 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
8681 else "Reference is not a spatial image"
8782 )
8883
84+ fieldshape = (* self .reference .shape , self .reference .ndim )
85+ if field is not None :
86+ field = _ensure_image (field )
87+ self ._field = np .squeeze (
88+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89+ )
90+ if fieldshape != self ._field .shape :
91+ raise TransformError (
92+ f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93+ f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94+ )
95+ else :
96+ self ._field = np .zeros (fieldshape , dtype = "float32" )
97+ self ._is_deltas = True
98+
8999 if self ._field .shape [- 1 ] != self .ndim :
90100 raise TransformError (
91101 "The number of components of the field (%d) does not match "
92102 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], self .ndim )
93103 )
94104
95- if is_deltas :
105+ if self . is_deltas :
96106 self ._deltas = (
97107 self ._field .copy ()
98108 ) # IMPORTANT: you don't want to update deltas
99109 # Convert from displacements (deltas) to deformations fields
100110 # (just add its origin to each delta vector)
101- self ._field += self .reference .ndcoords .T .reshape (self . _field . shape )
111+ self ._field += self .reference .ndcoords .T .reshape (fieldshape )
102112
103113 def __repr__ (self ):
104114 """Beautify the python representation."""
105115 return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
106116
117+ @property
118+ def is_deltas (self ):
119+ """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+ return self ._is_deltas
121+
107122 @property
108123 def ndim (self ):
109124 """Get the dimensions of the transform."""
@@ -232,7 +247,7 @@ def __eq__(self, other):
232247 True
233248
234249 """
235- _eq = np .array_equal (self ._field , other ._field )
250+ _eq = np .allclose (self ._field , other ._field )
236251 if _eq and self ._reference != other ._reference :
237252 warnings .warn ("Fields are equal, but references do not match." )
238253 return _eq
@@ -255,9 +270,9 @@ def to_x5(self, metadata=None):
255270 return io .x5 .X5Transform (
256271 type = "nonlinear" ,
257272 subtype = "densefield" ,
258- representation = "displacements" ,
273+ representation = "displacements" if self . is_deltas else "deformations" ,
259274 metadata = metadata ,
260- transform = self ._deltas ,
275+ transform = self ._deltas if self . is_deltas else self . _field ,
261276 dimension_kinds = kinds ,
262277 domain = domain ,
263278 )
@@ -275,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"):
275290 raise NotImplementedError (f"Unsupported format <{ fmt } >" )
276291
277292 if fmt == "X5" :
278- from .io .x5 import from_filename as load_x5
279-
280293 x5_xfm = load_x5 (filename )[0 ]
281294 Domain = namedtuple ("Domain" , "affine shape" )
282295 reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
283- return cls (x5_xfm .transform , is_deltas = True , reference = reference )
296+ field = nb .Nifti1Image (x5_xfm .transform , reference .affine )
297+ return cls (
298+ field ,
299+ is_deltas = x5_xfm .representation == "displacements" ,
300+ reference = reference ,
301+ )
284302
285303 return cls (_factory [fmt .lower ()].from_filename (filename ))
286304
@@ -317,6 +335,24 @@ def ndim(self):
317335 """Get the dimensions of the transform."""
318336 return self ._coeffs .ndim - 1
319337
338+ @classmethod
339+ def from_filename (cls , filename , fmt = "X5" ):
340+ _factory = {
341+ "X5" : None ,
342+ }
343+ fmt = fmt .upper ()
344+ if fmt not in {k .upper () for k in _factory }:
345+ raise NotImplementedError (f"Unsupported format <{ fmt } >" )
346+
347+ x5_xfm = load_x5 (filename )[0 ]
348+ Domain = namedtuple ("Domain" , "affine shape" )
349+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
350+
351+ coefficients = nb .Nifti1Image (x5_xfm .transform , x5_xfm .additional_parameters )
352+ return cls (coefficients , reference = reference )
353+
354+ # return cls(_factory[fmt.lower()].from_filename(filename))
355+
320356 def to_field (self , reference = None , dtype = "float32" ):
321357 """Generate a displacements deformation field from this B-Spline field."""
322358 _ref = (
@@ -351,21 +387,17 @@ def to_x5(self, metadata=None):
351387 coordinates = "cartesian" ,
352388 )
353389
354- meta = metadata | {
355- "KnotsAffine" : self ._knots .affine .tolist (),
356- "KnotsShape" : self ._knots .shape ,
357- }
358-
359390 kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
360391
361392 return io .x5 .X5Transform (
362393 type = "nonlinear" ,
363394 subtype = "bspline" ,
364395 representation = "coefficients" ,
365- metadata = meta ,
396+ metadata = metadata ,
366397 transform = self ._coeffs ,
367398 dimension_kinds = kinds ,
368399 domain = domain ,
400+ additional_parameters = self ._knots .affine ,
369401 )
370402
371403 def map (self , x , inverse = False ):
0 commit comments