Skip to content

Commit c220702

Browse files
authored
Merge pull request #233 from firedrakeproject/ReubenHill/tsfc-interp-tidy
Tidy up interpolation tsfc<->firedrake interface
2 parents 50e6aab + ed89971 commit c220702

File tree

3 files changed

+82
-15
lines changed

3 files changed

+82
-15
lines changed

tests/test_dual_evaluation.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
import ufl
3+
from tsfc.finatinterface import create_element
4+
from tsfc import compile_expression_dual_evaluation
5+
6+
7+
def test_ufl_only_simple():
8+
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 1))
9+
V = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 2))
10+
v = ufl.Coefficient(V)
11+
expr = ufl.inner(v, v)
12+
W = V
13+
to_element = create_element(W.ufl_element())
14+
ast, oriented, needs_cell_sizes, coefficients, first_coeff_fake_coords, _ = compile_expression_dual_evaluation(expr, to_element, coffee=False)
15+
assert first_coeff_fake_coords is False
16+
17+
18+
def test_ufl_only_spatialcoordinate():
19+
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 1))
20+
V = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 2))
21+
x, y = ufl.SpatialCoordinate(mesh)
22+
expr = x*y - y**2 + x
23+
W = V
24+
to_element = create_element(W.ufl_element())
25+
ast, oriented, needs_cell_sizes, coefficients, first_coeff_fake_coords, _ = compile_expression_dual_evaluation(expr, to_element, coffee=False)
26+
assert first_coeff_fake_coords is True
27+
28+
29+
def test_ufl_only_from_contravariant_piola():
30+
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 1))
31+
V = ufl.FunctionSpace(mesh, ufl.FiniteElement("RT", ufl.triangle, 1))
32+
v = ufl.Coefficient(V)
33+
expr = ufl.inner(v, v)
34+
W = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 2))
35+
to_element = create_element(W.ufl_element())
36+
ast, oriented, needs_cell_sizes, coefficients, first_coeff_fake_coords, _ = compile_expression_dual_evaluation(expr, to_element, coffee=False)
37+
assert first_coeff_fake_coords is True
38+
39+
40+
def test_ufl_only_to_contravariant_piola():
41+
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 1))
42+
V = ufl.FunctionSpace(mesh, ufl.FiniteElement("P", ufl.triangle, 2))
43+
v = ufl.Coefficient(V)
44+
expr = ufl.as_vector([v, v])
45+
W = ufl.FunctionSpace(mesh, ufl.FiniteElement("RT", ufl.triangle, 1))
46+
to_element = create_element(W.ufl_element())
47+
ast, oriented, needs_cell_sizes, coefficients, first_coeff_fake_coords, _ = compile_expression_dual_evaluation(expr, to_element, coffee=False)
48+
assert first_coeff_fake_coords is True
49+
50+
51+
def test_ufl_only_shape_mismatch():
52+
mesh = ufl.Mesh(ufl.VectorElement("P", ufl.triangle, 1))
53+
V = ufl.FunctionSpace(mesh, ufl.FiniteElement("RT", ufl.triangle, 1))
54+
v = ufl.Coefficient(V)
55+
expr = ufl.inner(v, v)
56+
assert expr.ufl_shape == ()
57+
W = V
58+
to_element = create_element(W.ufl_element())
59+
assert to_element.value_shape == (2,)
60+
with pytest.raises(ValueError):
61+
ast, oriented, needs_cell_sizes, coefficients, first_coeff_fake_coords, _ = compile_expression_dual_evaluation(expr, to_element, coffee=False)

tsfc/driver.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def name_multiindex(multiindex, name):
267267
return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule)
268268

269269

270-
def compile_expression_dual_evaluation(expression, to_element, coordinates, *,
270+
def compile_expression_dual_evaluation(expression, to_element, *,
271271
domain=None, interface=None,
272272
parameters=None, coffee=False):
273273
"""Compile a UFL expression to be evaluated against a compile-time known reference element's dual basis.
@@ -276,8 +276,7 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, *,
276276
277277
:arg expression: UFL expression
278278
:arg to_element: A FInAT element for the target space
279-
:arg coordinates: the coordinate function
280-
:arg domain: optional UFL domain the expression is defined on (useful when expression contains no domain).
279+
:arg domain: optional UFL domain the expression is defined on (required when expression contains no domain).
281280
:arg interface: backend module for the kernel interface
282281
:arg parameters: parameters object
283282
:arg coffee: compile coffee kernel instead of loopy kernel
@@ -328,25 +327,29 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, *,
328327
argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices()
329328
for arg in arguments)
330329

331-
# Replace coordinates (if any)
332-
domain = expression.ufl_domain()
333-
if domain:
334-
assert coordinates.ufl_domain() == domain
335-
builder.domain_coordinate[domain] = coordinates
336-
builder.set_cell_sizes(domain)
330+
# Replace coordinates (if any) unless otherwise specified by kwarg
331+
if domain is None:
332+
domain = expression.ufl_domain()
333+
assert domain is not None
337334

338335
# Collect required coefficients
336+
first_coefficient_fake_coords = False
339337
coefficients = extract_coefficients(expression)
340338
if has_type(expression, GeometricQuantity) or any(fem.needs_coordinate_mapping(c.ufl_element()) for c in coefficients):
341-
coefficients = [coordinates] + coefficients
339+
# Create a fake coordinate coefficient for a domain.
340+
coords_coefficient = ufl.Coefficient(ufl.FunctionSpace(domain, domain.ufl_coordinate_element()))
341+
builder.domain_coordinate[domain] = coords_coefficient
342+
builder.set_cell_sizes(domain)
343+
coefficients = [coords_coefficient] + coefficients
344+
first_coefficient_fake_coords = True
342345
builder.set_coefficients(coefficients)
343346

344347
# Split mixed coefficients
345348
expression = ufl_utils.split_coefficients(expression, builder.coefficient_split)
346349

347350
# Translate to GEM
348351
kernel_cfg = dict(interface=builder,
349-
ufl_cell=coordinates.ufl_domain().ufl_cell(),
352+
ufl_cell=domain.ufl_cell(),
350353
argument_multiindices=argument_multiindices,
351354
index_cache={},
352355
scalar_type=parameters["scalar_type"])
@@ -431,7 +434,7 @@ def compile_expression_dual_evaluation(expression, to_element, coordinates, *,
431434
# Handle kernel interface requirements
432435
builder.register_requirements([ir])
433436
# Build kernel tuple
434-
return builder.construct_kernel(return_arg, impero_c, index_names)
437+
return builder.construct_kernel(return_arg, impero_c, index_names, first_coefficient_fake_coords)
435438

436439

437440
def lower_integral_type(fiat_cell, integral_type):

tsfc/kernel_interface/firedrake_loopy.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
# Expression kernel description type
20-
ExpressionKernel = namedtuple('ExpressionKernel', ['ast', 'oriented', 'needs_cell_sizes', 'coefficients', 'tabulations'])
20+
ExpressionKernel = namedtuple('ExpressionKernel', ['ast', 'oriented', 'needs_cell_sizes', 'coefficients', 'first_coefficient_fake_coords', 'tabulations'])
2121

2222

2323
def make_builder(*args, **kwargs):
@@ -153,12 +153,14 @@ def register_requirements(self, ir):
153153
provided by the kernel interface."""
154154
self.oriented, self.cell_sizes, self.tabulations = check_requirements(ir)
155155

156-
def construct_kernel(self, return_arg, impero_c, index_names):
156+
def construct_kernel(self, return_arg, impero_c, index_names, first_coefficient_fake_coords):
157157
"""Constructs an :class:`ExpressionKernel`.
158158
159159
:arg return_arg: loopy.GlobalArg for the return value
160160
:arg impero_c: gem.ImperoC object that represents the kernel
161161
:arg index_names: pre-assigned index names
162+
:arg first_coefficient_fake_coords: If true, the kernel's first
163+
coefficient is a constructed UFL coordinate field
162164
:returns: :class:`ExpressionKernel` object
163165
"""
164166
args = [return_arg]
@@ -173,7 +175,8 @@ def construct_kernel(self, return_arg, impero_c, index_names):
173175
loopy_kernel = generate_loopy(impero_c, args, self.scalar_type,
174176
"expression_kernel", index_names)
175177
return ExpressionKernel(loopy_kernel, self.oriented, self.cell_sizes,
176-
self.coefficients, self.tabulations)
178+
self.coefficients, first_coefficient_fake_coords,
179+
self.tabulations)
177180

178181

179182
class KernelBuilder(KernelBuilderBase):

0 commit comments

Comments
 (0)