77#
88### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99"""Nonlinear transforms."""
10+
1011import warnings
1112from functools import partial
13+ from collections import namedtuple
1214import numpy as np
15+ import nibabel as nb
1316
1417from nitransforms import io
1518from nitransforms .io .base import _ensure_image
19+ from nitransforms .io .x5 import from_filename as load_x5
1620from nitransforms .interp .bspline import grid_bspline_weights , _cubic_bspline
1721from nitransforms .base import (
1822 TransformBase ,
2226)
2327from scipy .ndimage import map_coordinates
2428
29+ # Avoids circular imports
30+ try :
31+ from nitransforms ._version import __version__
32+ except ModuleNotFoundError : # pragma: no cover
33+ __version__ = "0+unknown"
34+
2535
2636class DenseFieldTransform (TransformBase ):
2737 """Represents dense field (voxel-wise) transforms."""
2838
29- __slots__ = ("_field" , "_deltas" )
39+ __slots__ = ("_field" , "_deltas" , "_is_deltas" )
3040
3141 def __init__ (self , field = None , is_deltas = True , reference = None ):
3242 """
@@ -60,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):
6070
6171 super ().__init__ ()
6272
63- if field is not None :
64- field = _ensure_image (field )
65- self ._field = np .squeeze (
66- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
67- )
68- else :
69- self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
70- is_deltas = True
73+ self ._is_deltas = is_deltas
7174
7275 try :
7376 self .reference = ImageGrid (reference if reference is not None else field )
@@ -78,24 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
7881 else "Reference is not a spatial image"
7982 )
8083
84+ fieldshape = (* self .reference .shape , self .reference .ndim )
85+ if field is not None :
86+ field = _ensure_image (field )
87+ self ._field = np .squeeze (
88+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
89+ )
90+ if fieldshape != self ._field .shape :
91+ raise TransformError (
92+ f"Shape of the field ({ 'x' .join (str (i ) for i in self ._field .shape )} ) "
93+ f"doesn't match that of the reference({ 'x' .join (str (i ) for i in fieldshape )} )"
94+ )
95+ else :
96+ self ._field = np .zeros (fieldshape , dtype = "float32" )
97+ self ._is_deltas = True
98+
8199 if self ._field .shape [- 1 ] != self .ndim :
82100 raise TransformError (
83101 "The number of components of the field (%d) does not match "
84102 "the number of dimensions (%d)" % (self ._field .shape [- 1 ], self .ndim )
85103 )
86104
87- if is_deltas :
105+ if self . _is_deltas :
88106 self ._deltas = (
89107 self ._field .copy ()
90108 ) # IMPORTANT: you don't want to update deltas
91109 # Convert from displacements (deltas) to deformations fields
92110 # (just add its origin to each delta vector)
93- self ._field += self .reference .ndcoords .T .reshape (self . _field . shape )
111+ self ._field += self .reference .ndcoords .T .reshape (fieldshape )
94112
95113 def __repr__ (self ):
96114 """Beautify the python representation."""
97115 return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
98116
117+ @property
118+ def is_deltas (self ):
119+ """Check whether this is a displacements (``True``) or a deformation (``False``) field."""
120+ return self ._is_deltas
121+
99122 @property
100123 def ndim (self ):
101124 """Get the dimensions of the transform."""
@@ -224,22 +247,60 @@ def __eq__(self, other):
224247 True
225248
226249 """
227- _eq = np .array_equal (self ._field , other ._field )
250+ _eq = np .allclose (self ._field , other ._field )
228251 if _eq and self ._reference != other ._reference :
229252 warnings .warn ("Fields are equal, but references do not match." )
230253 return _eq
231254
255+ def to_x5 (self , metadata = None ):
256+ """Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
257+ metadata = {"WrittenBy" : f"NiTransforms { __version__ } " } | (metadata or {})
258+
259+ domain = None
260+ if (reference := self .reference ) is not None :
261+ domain = io .x5 .X5Domain (
262+ grid = True ,
263+ size = getattr (reference , "shape" , (0 , 0 , 0 )),
264+ mapping = reference .affine ,
265+ coordinates = "cartesian" ,
266+ )
267+
268+ kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
269+
270+ return io .x5 .X5Transform (
271+ type = "nonlinear" ,
272+ subtype = "densefield" ,
273+ representation = "displacements" if self .is_deltas else "deformations" ,
274+ metadata = metadata ,
275+ transform = self ._deltas if self .is_deltas else self ._field ,
276+ dimension_kinds = kinds ,
277+ domain = domain ,
278+ )
279+
232280 @classmethod
233281 def from_filename (cls , filename , fmt = "X5" ):
234282 _factory = {
235283 "afni" : io .afni .AFNIDisplacementsField ,
236284 "itk" : io .itk .ITKDisplacementsField ,
237285 "fsl" : io .fsl .FSLDisplacementsField ,
286+ "X5" : None ,
238287 }
239- if fmt not in _factory :
288+ fmt = fmt .upper ()
289+ if fmt not in {k .upper () for k in _factory }:
240290 raise NotImplementedError (f"Unsupported format <{ fmt } >" )
241291
242- return cls (_factory [fmt ].from_filename (filename ))
292+ if fmt == "X5" :
293+ x5_xfm = load_x5 (filename )[0 ]
294+ Domain = namedtuple ("Domain" , "affine shape" )
295+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
296+ field = nb .Nifti1Image (x5_xfm .transform , reference .affine )
297+ return cls (
298+ field ,
299+ is_deltas = x5_xfm .representation == "displacements" ,
300+ reference = reference ,
301+ )
302+
303+ return cls (_factory [fmt .lower ()].from_filename (filename ))
243304
244305
245306load = DenseFieldTransform .from_filename
@@ -274,6 +335,24 @@ def ndim(self):
274335 """Get the dimensions of the transform."""
275336 return self ._coeffs .ndim - 1
276337
338+ @classmethod
339+ def from_filename (cls , filename , fmt = "X5" ):
340+ _factory = {
341+ "X5" : None ,
342+ }
343+ fmt = fmt .upper ()
344+ if fmt not in {k .upper () for k in _factory }:
345+ raise NotImplementedError (f"Unsupported format <{ fmt } >" )
346+
347+ x5_xfm = load_x5 (filename )[0 ]
348+ Domain = namedtuple ("Domain" , "affine shape" )
349+ reference = Domain (x5_xfm .domain .mapping , x5_xfm .domain .size )
350+
351+ coefficients = nb .Nifti1Image (x5_xfm .transform , x5_xfm .additional_parameters )
352+ return cls (coefficients , reference = reference )
353+
354+ # return cls(_factory[fmt.lower()].from_filename(filename))
355+
277356 def to_field (self , reference = None , dtype = "float32" ):
278357 """Generate a displacements deformation field from this B-Spline field."""
279358 _ref = (
@@ -295,6 +374,32 @@ def to_field(self, reference=None, dtype="float32"):
295374 field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref
296375 )
297376
377+ def to_x5 (self , metadata = None ):
378+ """Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
379+ metadata = {"WrittenBy" : f"NiTransforms { __version__ } " } | (metadata or {})
380+
381+ domain = None
382+ if (reference := self .reference ) is not None :
383+ domain = io .x5 .X5Domain (
384+ grid = True ,
385+ size = getattr (reference , "shape" , (0 , 0 , 0 )),
386+ mapping = reference .affine ,
387+ coordinates = "cartesian" ,
388+ )
389+
390+ kinds = tuple ("space" for _ in range (self .ndim )) + ("vector" ,)
391+
392+ return io .x5 .X5Transform (
393+ type = "nonlinear" ,
394+ subtype = "bspline" ,
395+ representation = "coefficients" ,
396+ metadata = metadata ,
397+ transform = self ._coeffs ,
398+ dimension_kinds = kinds ,
399+ domain = domain ,
400+ additional_parameters = self ._knots .affine ,
401+ )
402+
298403 def map (self , x , inverse = False ):
299404 r"""
300405 Apply the transformation to a list of physical coordinate points.
0 commit comments