Skip to content

Commit 4c41492

Browse files
djkapnerfcollman
authored andcommitted
Add fit estimate to tps (#123)
* clunky, but working copy of javafor computing dMtxDat * working with tests, still clunky * added stand-alone test, not dependent on pre-computed json from java client
1 parent d3900a8 commit 4c41492

File tree

4 files changed

+207
-4
lines changed

4 files changed

+207
-4
lines changed

renderapi/transform/leaf/thin_plate_spline.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,101 @@ def inverse_tform(
180180
newpts.append(npt)
181181
return np.array(newpts)
182182

183+
@staticmethod
184+
def fit(A, B, computeAffine=True):
185+
"""function to fit this transform given the corresponding sets of points A & B
186+
187+
Parameters
188+
----------
189+
A : numpy.array
190+
a Nx2 matrix of source points
191+
B : numpy.array
192+
a Nx2 matrix of destination points
193+
194+
Returns
195+
-------
196+
dMatrix : numpy.array
197+
ndims x nLm
198+
aMatrix : numpy.array
199+
ndims x ndims, affine matrix
200+
bVector : numpy.array
201+
ndims x 1, translation vector
202+
"""
203+
204+
if not all([A.shape[0] == B.shape[0], A.shape[1] == B.shape[1] == 2]):
205+
raise EstimationError(
206+
'shape mismatch! A shape: {}, B shape {}'.format(
207+
A.shape, B.shape))
208+
209+
# build displacements
210+
ndims = B.shape[1]
211+
nLm = B.shape[0]
212+
y = (B - A).flatten()
213+
214+
# compute K
215+
# tempting to matricize this, but, nLm x nLm can get big
216+
# settle for vectorize
217+
kMatrix = np.zeros((ndims * nLm, ndims * nLm))
218+
for i in range(nLm):
219+
r = np.linalg.norm(A[i, :] - A, axis=1)
220+
nrm = np.zeros_like(r)
221+
ind = np.argwhere(r > 1e-8)
222+
nrm[ind] = r[ind] * r[ind] * np.log(r[ind])
223+
kMatrix[i * ndims, 0::2] = nrm
224+
kMatrix[(i * ndims + 1)::2, 1::2] = nrm
225+
226+
# compute L
227+
lMatrix = kMatrix
228+
if computeAffine:
229+
pMatrix = np.tile(np.eye(ndims), (nLm, ndims + 1))
230+
for d in range(ndims):
231+
pMatrix[0::2, d*ndims] = A[:, d]
232+
pMatrix[1::2, d*ndims + 1] = A[:, d]
233+
lMatrix = np.zeros(
234+
(ndims * (nLm + ndims + 1), ndims * (nLm + ndims + 1)))
235+
lMatrix[
236+
0: pMatrix.shape[0],
237+
kMatrix.shape[1]: kMatrix.shape[1] + pMatrix.shape[1]] = \
238+
pMatrix
239+
pMatrix = np.transpose(pMatrix)
240+
lMatrix[
241+
kMatrix.shape[0]: kMatrix.shape[0] + pMatrix.shape[0],
242+
0: pMatrix.shape[1]] = pMatrix
243+
lMatrix[0: ndims * nLm, 0: ndims * nLm] = kMatrix
244+
y = np.append(y, np.zeros(ndims * (ndims + 1)))
245+
246+
wMatrix = np.linalg.solve(lMatrix, y)
247+
248+
dMatrix = np.reshape(wMatrix[0: ndims * nLm], (ndims, nLm), order='F')
249+
aMatrix = None
250+
bVector = None
251+
if computeAffine:
252+
aMatrix = np.reshape(
253+
wMatrix[ndims * nLm: ndims * nLm + ndims * ndims],
254+
(ndims, ndims),
255+
order='F')
256+
bVector = wMatrix[ndims * nLm + ndims * ndims:]
257+
258+
return dMatrix, aMatrix, bVector
259+
260+
def estimate(self, A, B, computeAffine=True):
261+
"""method for setting this transformation with the best fit
262+
given the corresponding points A,B
263+
Parameters
264+
----------
265+
A : numpy.array
266+
a Nx2 matrix of source points
267+
B : numpy.array
268+
a Nx2 matrix of destination points
269+
computeAffine: boolean
270+
whether to include an affine computation
271+
"""
272+
273+
self.dMtxDat, self.aMtx, self.bVec = self.fit(
274+
A, B, computeAffine=computeAffine)
275+
(self.nLm, self.ndims) = B.shape
276+
self.srcPts = np.transpose(A)
277+
183278
@property
184279
def dataString(self):
185280
header = 'ThinPlateSplineR2LogR {} {}'.format(self.ndims, self.nLm)

test/rendersettings.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
'owner': 'renderowner',
1010
'project': 'renderproject',
1111
'client_scripts': '/path/to/client_scripts'
12-
}
12+
}
1313

1414
DEFAULT_RENDER_CLIENT = dict(DEFAULT_RENDER, **{
1515
'client_script': '/path/to/client_scripts/run_ws_client.sh',
@@ -21,7 +21,7 @@
2121
'RENDER_OWNER': DEFAULT_RENDER['owner'],
2222
'RENDER_PROJECT': DEFAULT_RENDER['project'],
2323
'RENDER_CLIENT_SCRIPTS': DEFAULT_RENDER['client_scripts']
24-
}
24+
}
2525

2626
DEFAULT_RENDER_CLIENT_ENVIRONMENT_VARIABLES = dict(
2727
DEFAULT_RENDER_ENVIRONMENT_VARIABLES, **{
@@ -33,8 +33,12 @@
3333

3434
TEST_TILESPECS_FILE = os.path.join(TEST_FILES_DIR, 'tilespecs.json')
3535

36-
TEST_THINPLATESPLINE_FILE = os.path.join(TEST_FILES_DIR,
37-
'thin_plate_spline.json')
36+
TEST_THINPLATESPLINE_FILE = os.path.join(
37+
TEST_FILES_DIR,
38+
'thin_plate_spline.json')
39+
TEST_THINPLATESPLINEAFFINE_FILE = os.path.join(
40+
TEST_FILES_DIR,
41+
'thin_plate_spline_affine.json')
3842

3943
INTERPOLATED_TRANSFORM_TILESPEC = os.path.join(
4044
TEST_FILES_DIR, 'tilespec_interpolated.json')
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"type" : "leaf",
3+
"className" : "mpicbg.trakem2.transform.ThinPlateSplineTransform",
4+
"dataString" : "ThinPlateSplineR2LogR 2 100 @P4R64SLPrGq8kudjBEERFDykx/MHaYpjP2iTdKla/W/AWp5BXNVtaMBXQhhumRdE eJx9VH1UjHkULps6ks+S1IqOWkQpW9k4Xu3aPUeFxIxEUdmtlTgVq0+bmSaqpVCRYWoxanvLUFrK8qYimzmkqCWl8bErsaeRb0527/x+ve85c+3Z+ec59/bc5zy/e583AwP9n1ciy5YaOUwV6kyWZbf+HSXUuYgvh7/XDvER6p+hLkxaJdRKxC8Wi8UhsmNCzSK+yuA/f7yvj3DAn1DnIv6AP6Ee8CfUSsQf8CfULOKr/n9PH/nCmIv4vD++5v3xtRLxeX98zSK+Sp+P7/bRnrAvjLw/vub98bUS8Xl/fM0ivkqfL/jDd8J7wT74GiPvj6+ViM/742sW8VX6fMEfzg2+E94L9sHXGJWIz/vjaxbxVfp8wR/OMc4NvhPeC/ahRDweeX98zSK+Sp8v+MPfFc4xzg2+E94L9sHXGFnEV+nzBX/4O8ffFc4xzg2+E94L9sEiHo8qfb7gD/3fwd85/q5wjnFu8J3wXrAPoUbInDRbmS9yDOZKG/peG/YXcNK6ZwFmtm7c7rqQ9L7wOdyu+mh3qXk3YKjMdZsX9IMdzarKOGm9lXGo7C3MPWp0C6oGnRUvG236mZxV0fcmKGO5/Nm/jOc8zJmcs/3q8zfdmO2SpwVlRWrmpwW3mGcWpoCaMvtQF+i/+GrFowfA6321cpotzMktPH3DQWed9rvTrcz+COtI81RLTpFaVd8U/RsTnXvpfYN/IRd7LPvJ1e9TuIhp2Qb3ZH2A6+6rk5yhL9k7UbsdeDvT14+8DXMlnhXLI0HH4sozwwgm762ka8LcGk7+IO2a7/FgJlnatqSjOJTzLhiUvTU2kfGLE7VHdY4D7J4bJcrhvBVZ1W3nKoGnGO1zRg5zW47WvksGneQJ85vvMHkFL5r67vlx8vphnTuHHGYSwpTDPRPVXKBoQ3fTLAfGbprGasySU4yds/Pwyph50HcbX/bSGniRd2c4tXDyug8nJy8TgU6vw9rRIVSvPZjqbe4nel+0hxO93xM8iJ5vxWCil6UOJXolylSi5zKpiOr5Xqd6Trb0vYP30/cGBJL3dpYGkPfuan1D3rtBWkbeu/7QbvLejsqZ5L2+K03Ie+tS/iDv/drImN6jfAS9x2Jnco+L5iJyD7V9MrnHg/mL6D2sbOk9ao3pPdKayT3Kn04n99Aeiqd5SUqhedl4muSFW9NJ8lL6IZPmZfQkmpfKQJqXGzdIXgJbltG8SCJJXsJnLaN5dmqlee7R0Dwfy6B5XhBE8/xlB82zVwPNs+YHkueQC440z5vH0jybx+j0TFQ/toD+xsv3vRLBt1VVS10x7GPbJCNtOez5uU37egVF732kP9jahPIaUumciyHVkY4B/V7toahR4FvZZf3wLOyj8ta56puwZ+mI8MO6u5vaqFoZijtCaL8mk/LudNO5G3KqE+8A73utmHG3CPbxbsPjmHjYc3XFUI9CuF/j9fx5NZCL/QVrehIJrm4bR/oHZuwhPNP0q3RunTXVOVWm28dn5Seu6fY8e+2sBN39ikdG+EIusl2VDrm6vIm0cwwJ9sRYkn7FRcobnuNP5yyXU52E97BniWLq3G/gfur3l+3EkIv6vVODNjN+W2SyKxWvGLspigsLJxYSdP02lvRbaw8Q3vSwE3TO3onqaP4CXPJppuM26D/tXfSwCXju/ilB8YxffFWOYXczY+caHZ61NIGg7LAd6ZsF5BGepNKNzC3es4noZHyw0uVi5yKfJ7q8FR3J6NLluK3LyVj3fYwtHTZZ992V7ziYSlCS5kn658WTCe+uhQedc59OdfwzIEejLjDiX2Gvz0XO8XGw5/wWhwYx7L1kb9ZjJ7hD3GKTkmaK6W9If7f/Lco7uJDOxZlSnejVcOeepgN/7oO7K5Z2yM5ADo7bX3LfBLlIDCsqGAo5GRRrVD2e4tI+0i8OHUp4FzX/0LlNNlSn6xPIYeD2jXVvIJdRQ45OuQQ5tUye+XkN5HbrneC8PMix9nZSXhjFgym0/0pDeK79OWTuyHMXqtNo8S+ZhThG"
5+
}

test/test_transform.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
EPSILON = 0.0000000001
10+
EPSILON2 = 0.000000001
1011

1112

1213
def cross_py23_reload(module):
@@ -804,6 +805,104 @@ def test_thinplatespline():
804805
dst_pts = t.tform(src_pts)
805806
assert(dst_pts.shape == src_pts.shape)
806807

808+
# check load from json
809+
jt = t.to_dict()
810+
nt = renderapi.transform.ThinPlateSplineTransform(
811+
json=jt)
812+
assert nt == t
813+
814+
815+
def test_thinplatespline_apply():
816+
# tests some copied behavior from trakem2
817+
j = json.load(open(rendersettings.TEST_THINPLATESPLINE_FILE, 'r'))
818+
t = renderapi.transform.ThinPlateSplineTransform(
819+
dataString=j['dataString'])
820+
del t.dMtxDat
821+
pt = np.array([1.234, 45.678])
822+
npt = t.apply(pt)
823+
assert np.all(pt == npt)
824+
825+
826+
def estimate_test(jpath, computeAffine=True):
827+
# test that the estimate method can produce same results
828+
with open(jpath, 'r') as f:
829+
j = json.load(f)
830+
831+
t = renderapi.transform.ThinPlateSplineTransform(
832+
dataString=j['dataString'])
833+
# exact points
834+
src1 = np.transpose(t.srcPts)
835+
# some in-between points
836+
x = np.linspace(
837+
src1[:, 0].min(),
838+
src1[:, 0].max(),
839+
int(np.sqrt(t.nLm)) * 3)
840+
y = np.linspace(
841+
src1[:, 1].min(),
842+
src1[:, 1].max(),
843+
int(np.sqrt(t.nLm)) * 3)
844+
xt, yt = np.meshgrid(x, y)
845+
src2 = np.transpose(np.vstack((xt.flatten(), yt.flatten())))
846+
dst1_a = t.tform(src1)
847+
dst2_a = t.tform(src2)
848+
849+
# estimate
850+
t.estimate(src1, dst1_a, computeAffine=computeAffine)
851+
dst1_b = t.tform(src1)
852+
dst2_b = t.tform(src2)
853+
delta1 = np.linalg.norm(dst1_b - dst1_a, axis=1)
854+
delta2 = np.linalg.norm(dst2_b - dst2_a, axis=1)
855+
856+
assert delta1.max() < EPSILON2
857+
assert delta2.max() < EPSILON2
858+
859+
with pytest.raises(renderapi.errors.EstimationError):
860+
t.estimate(src1, dst1_a[1:, :])
861+
862+
863+
def test_thinplatespline_estimate():
864+
estimate_test(
865+
rendersettings.TEST_THINPLATESPLINE_FILE,
866+
computeAffine=False)
867+
estimate_test(
868+
rendersettings.TEST_THINPLATESPLINEAFFINE_FILE,
869+
computeAffine=True)
870+
thinplate_estimate_nojson(computeAffine=True)
871+
thinplate_estimate_nojson(computeAffine=False)
872+
873+
874+
def thinplate_estimate_nojson(computeAffine=True):
875+
# an estimate test that does not depend on pre-computed json
876+
x = np.linspace(0, 1000, 40)
877+
xt, yt = np.meshgrid(x, x)
878+
src = np.transpose(np.vstack((xt.flatten(), yt.flatten())))
879+
880+
def poly2d(src):
881+
dst = np.zeros_like(src)
882+
dx = (src[:, 0] - src[:, 0].mean()) / (src[:, 0].ptp())
883+
dy = (src[:, 1] - src[:, 1].mean()) / (src[:, 1].ptp())
884+
dst[:, 0] = src[:, 0] + 0.5 * dx * dy + 7.0 * dx * dy * dy
885+
dst[:, 1] = src[:, 1] + 0.7 * dy * dy - 4.0 * dx * dx * dy
886+
a = np.array([[1.1, -0.04], [-0.02, 0.95]])
887+
b = np.array([3.0, 4.0])
888+
dst = np.transpose(a.dot(np.transpose(dst))) + b
889+
return dst
890+
891+
dst = poly2d(src)
892+
t = renderapi.transform.ThinPlateSplineTransform()
893+
t.estimate(src, dst, computeAffine=computeAffine)
894+
895+
x = np.linspace(0, 1000, 120)
896+
xt, yt = np.meshgrid(x, x)
897+
test_src = np.transpose(np.vstack((xt.flatten(), yt.flatten())))
898+
p_dst = poly2d(test_src)
899+
t_dst = t.tform(test_src)
900+
901+
delta = np.linalg.norm(p_dst - t_dst, axis=1)
902+
# can it match the polynomial within a pixel?
903+
# for low-N, it won't
904+
assert delta.max() < 1.0
905+
807906

808907
def test_encode64():
809908
# case for Stephan's '@' character

0 commit comments

Comments
 (0)