Skip to content

Commit 6a2c5b5

Browse files
committed
refactor kernel_interface
1 parent 081db51 commit 6a2c5b5

File tree

4 files changed

+613
-204
lines changed

4 files changed

+613
-204
lines changed

tsfc/kernel_interface/common.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
1+
import collections
2+
import string
3+
import operator
4+
from functools import reduce
5+
from itertools import chain
6+
17
import numpy
8+
from numpy import asarray
29

310
import coffee.base as coffee
411

12+
from ufl.utils.sequences import max_degree
13+
514
import gem
615

716
from gem.utils import cached_property
17+
import gem.impero_utils as impero_utils
818

19+
from tsfc import fem, ufl_utils
920
from tsfc.kernel_interface import KernelInterface
21+
from tsfc.finatinterface import as_fiat_cell
22+
from tsfc.logging import logger
23+
24+
from FIAT.reference_element import TensorProductCell
25+
26+
from finat.quadrature import AbstractQuadratureRule, make_quadrature
1027

1128

1229
class KernelBuilderBase(KernelInterface):
@@ -107,3 +124,246 @@ def register_requirements(self, ir):
107124
"""
108125
# Nothing is required by default
109126
pass
127+
128+
129+
class KernelBuilderMixin(object):
130+
"""Mixin for KernelBuilder classes."""
131+
132+
def compile_ufl(self, integrand, params, argument_multiindices=None):
133+
"""Compile UFL integrand.
134+
135+
:arg integrand: UFL integrand.
136+
:arg params: a dict of parameters containing quadrature info.
137+
:kwarg argument_multiindices: multiindices to use to index
138+
arguments contained in the given integrand. If None, the
139+
"true" argument multiindices (`self.argument_multiindices`)
140+
used in the `self.return_variables` are used.
141+
142+
.. note::
143+
Problem solving environments can pass any multiindices to be
144+
used in the returned gem expression, e.g.,
145+
`self.argument_multiindices_dummy`. They are then responsible
146+
for applying appropriate operations to the returned gem
147+
expression so that the indices in the resulting gem expression
148+
match those in `self.return_variables`.
149+
"""
150+
# Split Coefficients
151+
if self.coefficient_split:
152+
integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split)
153+
# Compile: ufl -> gem
154+
functions = list(self.arguments) + [self.coordinate(self.integral_data.domain)] + list(self.integral_data.coefficients)
155+
_set_quad_rule(params, self.integral_data.domain.ufl_cell(), self.integral_data.integral_type, functions)
156+
quad_rule = params["quadrature_rule"]
157+
config = self.fem_config.copy()
158+
config.update(quadrature_rule=quad_rule)
159+
config['argument_multiindices'] = argument_multiindices or self.argument_multiindices
160+
expressions = fem.compile_ufl(integrand,
161+
interior_facet=self.interior_facet,
162+
**config)
163+
self.quadrature_indices.extend(quad_rule.point_set.indices)
164+
return expressions
165+
166+
def construct_integrals(self, expressions, params):
167+
mode = pick_mode(params["mode"])
168+
return mode.Integrals(expressions,
169+
params["quadrature_rule"].point_set.indices,
170+
self.argument_multiindices,
171+
params)
172+
173+
def stash_integrals(self, reps, params):
174+
mode = pick_mode(params["mode"])
175+
mode_irs = self.mode_irs
176+
mode_irs.setdefault(mode, collections.OrderedDict())
177+
for var, rep in zip(self.return_variables, reps):
178+
mode_irs[mode].setdefault(var, []).append(rep)
179+
180+
def compile_gem(self):
181+
# Finalise mode representations into a set of assignments
182+
mode_irs = self.mode_irs
183+
index_cache = self.fem_config['index_cache']
184+
185+
assignments = []
186+
for mode, var_reps in mode_irs.items():
187+
assignments.extend(mode.flatten(var_reps.items(), index_cache))
188+
189+
if assignments:
190+
return_variables, expressions = zip(*assignments)
191+
else:
192+
return_variables = []
193+
expressions = []
194+
195+
# Need optimised roots
196+
options = dict(reduce(operator.and_,
197+
[mode.finalise_options.items()
198+
for mode in mode_irs.keys()]))
199+
expressions = impero_utils.preprocess_gem(expressions, **options)
200+
201+
# Let the kernel interface inspect the optimised IR to register
202+
# what kind of external data is required (e.g., cell orientations,
203+
# cell sizes, etc.).
204+
oriented, needs_cell_sizes, tabulations = self.register_requirements(expressions)
205+
206+
# Construct ImperoC
207+
assignments = list(zip(return_variables, expressions))
208+
index_ordering = get_index_ordering(self.quadrature_indices, return_variables)
209+
try:
210+
impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True)
211+
except impero_utils.NoopError:
212+
impero_c = None
213+
return impero_c, oriented, needs_cell_sizes, tabulations
214+
215+
@cached_property
216+
def argument_multiindices_dummy(self):
217+
return tuple(tuple(gem.Index(extent=a.extent) for a in arg)
218+
for arg in self.argument_multiindices)
219+
220+
@cached_property
221+
def fem_config(self):
222+
# Map from UFL FiniteElement objects to multiindices. This is
223+
# so we reuse Index instances when evaluating the same coefficient
224+
# multiple times with the same table.
225+
#
226+
# We also use the same dict for the unconcatenate index cache,
227+
# which maps index objects to tuples of multiindices. These two
228+
# caches shall never conflict as their keys have different types
229+
# (UFL finite elements vs. GEM index objects).
230+
#
231+
# -> fem_config['index_cache']
232+
integral_type = self.integral_data.integral_type
233+
cell = self.integral_data.domain.ufl_cell()
234+
fiat_cell = as_fiat_cell(cell)
235+
integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type)
236+
return dict(interface=self,
237+
ufl_cell=cell,
238+
integral_type=integral_type,
239+
integration_dim=integration_dim,
240+
entity_ids=entity_ids,
241+
index_cache={},
242+
scalar_type=self.fem_scalar_type)
243+
244+
245+
def get_index_ordering(quadrature_indices, return_variables):
246+
split_argument_indices = tuple(chain(*[var.index_ordering()
247+
for var in return_variables]))
248+
return tuple(quadrature_indices) + split_argument_indices
249+
250+
251+
def get_index_names(quadrature_indices, argument_multiindices, index_cache):
252+
index_names = []
253+
254+
def name_index(index, name):
255+
index_names.append((index, name))
256+
if index in index_cache:
257+
for multiindex, suffix in zip(index_cache[index],
258+
string.ascii_lowercase):
259+
name_multiindex(multiindex, name + suffix)
260+
261+
def name_multiindex(multiindex, name):
262+
if len(multiindex) == 1:
263+
name_index(multiindex[0], name)
264+
else:
265+
for i, index in enumerate(multiindex):
266+
name_index(index, name + str(i))
267+
268+
name_multiindex(quadrature_indices, 'ip')
269+
for multiindex, name in zip(argument_multiindices, ['j', 'k']):
270+
name_multiindex(multiindex, name)
271+
return index_names
272+
273+
274+
def _set_quad_rule(params, cell, integral_type, functions):
275+
# Check if the integral has a quad degree attached, otherwise use
276+
# the estimated polynomial degree attached by compute_form_data
277+
try:
278+
quadrature_degree = params["quadrature_degree"]
279+
except KeyError:
280+
quadrature_degree = params["estimated_polynomial_degree"]
281+
function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions]
282+
if all((asarray(quadrature_degree) > 10 * asarray(degree)).all()
283+
for degree in function_degrees):
284+
logger.warning("Estimated quadrature degree %s more "
285+
"than tenfold greater than any "
286+
"argument/coefficient degree (max %s)",
287+
quadrature_degree, max_degree(function_degrees))
288+
if params.get("quadrature_rule") == "default":
289+
del params["quadrature_rule"]
290+
try:
291+
quad_rule = params["quadrature_rule"]
292+
except KeyError:
293+
fiat_cell = as_fiat_cell(cell)
294+
integration_dim, _ = lower_integral_type(fiat_cell, integral_type)
295+
integration_cell = fiat_cell.construct_subelement(integration_dim)
296+
quad_rule = make_quadrature(integration_cell, quadrature_degree)
297+
params["quadrature_rule"] = quad_rule
298+
299+
if not isinstance(quad_rule, AbstractQuadratureRule):
300+
raise ValueError("Expected to find a QuadratureRule object, not a %s" %
301+
type(quad_rule))
302+
303+
304+
def lower_integral_type(fiat_cell, integral_type):
305+
"""Lower integral type into the dimension of the integration
306+
subentity and a list of entity numbers for that dimension.
307+
308+
:arg fiat_cell: FIAT reference cell
309+
:arg integral_type: integral type (string)
310+
"""
311+
vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert']
312+
horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz']
313+
314+
dim = fiat_cell.get_dimension()
315+
if integral_type == 'cell':
316+
integration_dim = dim
317+
elif integral_type in ['exterior_facet', 'interior_facet']:
318+
if isinstance(fiat_cell, TensorProductCell):
319+
raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type))
320+
integration_dim = dim - 1
321+
elif integral_type == 'vertex':
322+
integration_dim = 0
323+
elif integral_type in vert_facet_types + horiz_facet_types:
324+
# Extrusion case
325+
if not isinstance(fiat_cell, TensorProductCell):
326+
raise ValueError("{} integral requires a TensorProductCell.".format(integral_type))
327+
basedim, extrdim = dim
328+
assert extrdim == 1
329+
330+
if integral_type in vert_facet_types:
331+
integration_dim = (basedim - 1, 1)
332+
elif integral_type in horiz_facet_types:
333+
integration_dim = (basedim, 0)
334+
else:
335+
raise NotImplementedError("integral type %s not supported" % integral_type)
336+
337+
if integral_type == 'exterior_facet_bottom':
338+
entity_ids = [0]
339+
elif integral_type == 'exterior_facet_top':
340+
entity_ids = [1]
341+
else:
342+
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))
343+
344+
return integration_dim, entity_ids
345+
346+
347+
def pick_mode(mode):
348+
"Return one of the specialized optimisation modules from a mode string."
349+
try:
350+
from firedrake_citations import Citations
351+
cites = {"vanilla": ("Homolya2017", ),
352+
"coffee": ("Luporini2016", "Homolya2017", ),
353+
"spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"),
354+
"tensor": ("Kirby2006", "Homolya2017", )}
355+
for c in cites[mode]:
356+
Citations().register(c)
357+
except ImportError:
358+
pass
359+
if mode == "vanilla":
360+
import tsfc.vanilla as m
361+
elif mode == "coffee":
362+
import tsfc.coffee_mode as m
363+
elif mode == "spectral":
364+
import tsfc.spectral as m
365+
elif mode == "tensor":
366+
import tsfc.tensor as m
367+
else:
368+
raise ValueError("Unknown mode: {}".format(mode))
369+
return m

0 commit comments

Comments
 (0)