|
| 1 | +import collections |
| 2 | +import string |
| 3 | +import operator |
| 4 | +from functools import reduce |
| 5 | +from itertools import chain |
| 6 | + |
1 | 7 | import numpy |
| 8 | +from numpy import asarray |
2 | 9 |
|
3 | 10 | import coffee.base as coffee |
4 | 11 |
|
| 12 | +from ufl.utils.sequences import max_degree |
| 13 | + |
5 | 14 | import gem |
6 | 15 |
|
7 | 16 | from gem.utils import cached_property |
| 17 | +import gem.impero_utils as impero_utils |
8 | 18 |
|
| 19 | +from tsfc import fem, ufl_utils |
9 | 20 | from tsfc.kernel_interface import KernelInterface |
| 21 | +from tsfc.finatinterface import as_fiat_cell |
| 22 | +from tsfc.logging import logger |
| 23 | + |
| 24 | +from FIAT.reference_element import TensorProductCell |
| 25 | + |
| 26 | +from finat.quadrature import AbstractQuadratureRule, make_quadrature |
10 | 27 |
|
11 | 28 |
|
12 | 29 | class KernelBuilderBase(KernelInterface): |
@@ -107,3 +124,246 @@ def register_requirements(self, ir): |
107 | 124 | """ |
108 | 125 | # Nothing is required by default |
109 | 126 | pass |
| 127 | + |
| 128 | + |
| 129 | +class KernelBuilderMixin(object): |
| 130 | + """Mixin for KernelBuilder classes.""" |
| 131 | + |
| 132 | + def compile_ufl(self, integrand, params, argument_multiindices=None): |
| 133 | + """Compile UFL integrand. |
| 134 | +
|
| 135 | + :arg integrand: UFL integrand. |
| 136 | + :arg params: a dict of parameters containing quadrature info. |
| 137 | + :kwarg argument_multiindices: multiindices to use to index |
| 138 | + arguments contained in the given integrand. If None, the |
| 139 | + "true" argument multiindices (`self.argument_multiindices`) |
| 140 | + used in the `self.return_variables` are used. |
| 141 | +
|
| 142 | + .. note:: |
| 143 | + Problem solving environments can pass any multiindices to be |
| 144 | + used in the returned gem expression, e.g., |
| 145 | + `self.argument_multiindices_dummy`. They are then responsible |
| 146 | + for applying appropriate operations to the returned gem |
| 147 | + expression so that the indices in the resulting gem expression |
| 148 | + match those in `self.return_variables`. |
| 149 | + """ |
| 150 | + # Split Coefficients |
| 151 | + if self.coefficient_split: |
| 152 | + integrand = ufl_utils.split_coefficients(integrand, self.coefficient_split) |
| 153 | + # Compile: ufl -> gem |
| 154 | + functions = list(self.arguments) + [self.coordinate(self.integral_data.domain)] + list(self.integral_data.coefficients) |
| 155 | + _set_quad_rule(params, self.integral_data.domain.ufl_cell(), self.integral_data.integral_type, functions) |
| 156 | + quad_rule = params["quadrature_rule"] |
| 157 | + config = self.fem_config.copy() |
| 158 | + config.update(quadrature_rule=quad_rule) |
| 159 | + config['argument_multiindices'] = argument_multiindices or self.argument_multiindices |
| 160 | + expressions = fem.compile_ufl(integrand, |
| 161 | + interior_facet=self.interior_facet, |
| 162 | + **config) |
| 163 | + self.quadrature_indices.extend(quad_rule.point_set.indices) |
| 164 | + return expressions |
| 165 | + |
| 166 | + def construct_integrals(self, expressions, params): |
| 167 | + mode = pick_mode(params["mode"]) |
| 168 | + return mode.Integrals(expressions, |
| 169 | + params["quadrature_rule"].point_set.indices, |
| 170 | + self.argument_multiindices, |
| 171 | + params) |
| 172 | + |
| 173 | + def stash_integrals(self, reps, params): |
| 174 | + mode = pick_mode(params["mode"]) |
| 175 | + mode_irs = self.mode_irs |
| 176 | + mode_irs.setdefault(mode, collections.OrderedDict()) |
| 177 | + for var, rep in zip(self.return_variables, reps): |
| 178 | + mode_irs[mode].setdefault(var, []).append(rep) |
| 179 | + |
| 180 | + def compile_gem(self): |
| 181 | + # Finalise mode representations into a set of assignments |
| 182 | + mode_irs = self.mode_irs |
| 183 | + index_cache = self.fem_config['index_cache'] |
| 184 | + |
| 185 | + assignments = [] |
| 186 | + for mode, var_reps in mode_irs.items(): |
| 187 | + assignments.extend(mode.flatten(var_reps.items(), index_cache)) |
| 188 | + |
| 189 | + if assignments: |
| 190 | + return_variables, expressions = zip(*assignments) |
| 191 | + else: |
| 192 | + return_variables = [] |
| 193 | + expressions = [] |
| 194 | + |
| 195 | + # Need optimised roots |
| 196 | + options = dict(reduce(operator.and_, |
| 197 | + [mode.finalise_options.items() |
| 198 | + for mode in mode_irs.keys()])) |
| 199 | + expressions = impero_utils.preprocess_gem(expressions, **options) |
| 200 | + |
| 201 | + # Let the kernel interface inspect the optimised IR to register |
| 202 | + # what kind of external data is required (e.g., cell orientations, |
| 203 | + # cell sizes, etc.). |
| 204 | + oriented, needs_cell_sizes, tabulations = self.register_requirements(expressions) |
| 205 | + |
| 206 | + # Construct ImperoC |
| 207 | + assignments = list(zip(return_variables, expressions)) |
| 208 | + index_ordering = get_index_ordering(self.quadrature_indices, return_variables) |
| 209 | + try: |
| 210 | + impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True) |
| 211 | + except impero_utils.NoopError: |
| 212 | + impero_c = None |
| 213 | + return impero_c, oriented, needs_cell_sizes, tabulations |
| 214 | + |
| 215 | + @cached_property |
| 216 | + def argument_multiindices_dummy(self): |
| 217 | + return tuple(tuple(gem.Index(extent=a.extent) for a in arg) |
| 218 | + for arg in self.argument_multiindices) |
| 219 | + |
| 220 | + @cached_property |
| 221 | + def fem_config(self): |
| 222 | + # Map from UFL FiniteElement objects to multiindices. This is |
| 223 | + # so we reuse Index instances when evaluating the same coefficient |
| 224 | + # multiple times with the same table. |
| 225 | + # |
| 226 | + # We also use the same dict for the unconcatenate index cache, |
| 227 | + # which maps index objects to tuples of multiindices. These two |
| 228 | + # caches shall never conflict as their keys have different types |
| 229 | + # (UFL finite elements vs. GEM index objects). |
| 230 | + # |
| 231 | + # -> fem_config['index_cache'] |
| 232 | + integral_type = self.integral_data.integral_type |
| 233 | + cell = self.integral_data.domain.ufl_cell() |
| 234 | + fiat_cell = as_fiat_cell(cell) |
| 235 | + integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type) |
| 236 | + return dict(interface=self, |
| 237 | + ufl_cell=cell, |
| 238 | + integral_type=integral_type, |
| 239 | + integration_dim=integration_dim, |
| 240 | + entity_ids=entity_ids, |
| 241 | + index_cache={}, |
| 242 | + scalar_type=self.fem_scalar_type) |
| 243 | + |
| 244 | + |
| 245 | +def get_index_ordering(quadrature_indices, return_variables): |
| 246 | + split_argument_indices = tuple(chain(*[var.index_ordering() |
| 247 | + for var in return_variables])) |
| 248 | + return tuple(quadrature_indices) + split_argument_indices |
| 249 | + |
| 250 | + |
| 251 | +def get_index_names(quadrature_indices, argument_multiindices, index_cache): |
| 252 | + index_names = [] |
| 253 | + |
| 254 | + def name_index(index, name): |
| 255 | + index_names.append((index, name)) |
| 256 | + if index in index_cache: |
| 257 | + for multiindex, suffix in zip(index_cache[index], |
| 258 | + string.ascii_lowercase): |
| 259 | + name_multiindex(multiindex, name + suffix) |
| 260 | + |
| 261 | + def name_multiindex(multiindex, name): |
| 262 | + if len(multiindex) == 1: |
| 263 | + name_index(multiindex[0], name) |
| 264 | + else: |
| 265 | + for i, index in enumerate(multiindex): |
| 266 | + name_index(index, name + str(i)) |
| 267 | + |
| 268 | + name_multiindex(quadrature_indices, 'ip') |
| 269 | + for multiindex, name in zip(argument_multiindices, ['j', 'k']): |
| 270 | + name_multiindex(multiindex, name) |
| 271 | + return index_names |
| 272 | + |
| 273 | + |
| 274 | +def _set_quad_rule(params, cell, integral_type, functions): |
| 275 | + # Check if the integral has a quad degree attached, otherwise use |
| 276 | + # the estimated polynomial degree attached by compute_form_data |
| 277 | + try: |
| 278 | + quadrature_degree = params["quadrature_degree"] |
| 279 | + except KeyError: |
| 280 | + quadrature_degree = params["estimated_polynomial_degree"] |
| 281 | + function_degrees = [f.ufl_function_space().ufl_element().degree() for f in functions] |
| 282 | + if all((asarray(quadrature_degree) > 10 * asarray(degree)).all() |
| 283 | + for degree in function_degrees): |
| 284 | + logger.warning("Estimated quadrature degree %s more " |
| 285 | + "than tenfold greater than any " |
| 286 | + "argument/coefficient degree (max %s)", |
| 287 | + quadrature_degree, max_degree(function_degrees)) |
| 288 | + if params.get("quadrature_rule") == "default": |
| 289 | + del params["quadrature_rule"] |
| 290 | + try: |
| 291 | + quad_rule = params["quadrature_rule"] |
| 292 | + except KeyError: |
| 293 | + fiat_cell = as_fiat_cell(cell) |
| 294 | + integration_dim, _ = lower_integral_type(fiat_cell, integral_type) |
| 295 | + integration_cell = fiat_cell.construct_subelement(integration_dim) |
| 296 | + quad_rule = make_quadrature(integration_cell, quadrature_degree) |
| 297 | + params["quadrature_rule"] = quad_rule |
| 298 | + |
| 299 | + if not isinstance(quad_rule, AbstractQuadratureRule): |
| 300 | + raise ValueError("Expected to find a QuadratureRule object, not a %s" % |
| 301 | + type(quad_rule)) |
| 302 | + |
| 303 | + |
| 304 | +def lower_integral_type(fiat_cell, integral_type): |
| 305 | + """Lower integral type into the dimension of the integration |
| 306 | + subentity and a list of entity numbers for that dimension. |
| 307 | +
|
| 308 | + :arg fiat_cell: FIAT reference cell |
| 309 | + :arg integral_type: integral type (string) |
| 310 | + """ |
| 311 | + vert_facet_types = ['exterior_facet_vert', 'interior_facet_vert'] |
| 312 | + horiz_facet_types = ['exterior_facet_bottom', 'exterior_facet_top', 'interior_facet_horiz'] |
| 313 | + |
| 314 | + dim = fiat_cell.get_dimension() |
| 315 | + if integral_type == 'cell': |
| 316 | + integration_dim = dim |
| 317 | + elif integral_type in ['exterior_facet', 'interior_facet']: |
| 318 | + if isinstance(fiat_cell, TensorProductCell): |
| 319 | + raise ValueError("{} integral cannot be used with a TensorProductCell; need to distinguish between vertical and horizontal contributions.".format(integral_type)) |
| 320 | + integration_dim = dim - 1 |
| 321 | + elif integral_type == 'vertex': |
| 322 | + integration_dim = 0 |
| 323 | + elif integral_type in vert_facet_types + horiz_facet_types: |
| 324 | + # Extrusion case |
| 325 | + if not isinstance(fiat_cell, TensorProductCell): |
| 326 | + raise ValueError("{} integral requires a TensorProductCell.".format(integral_type)) |
| 327 | + basedim, extrdim = dim |
| 328 | + assert extrdim == 1 |
| 329 | + |
| 330 | + if integral_type in vert_facet_types: |
| 331 | + integration_dim = (basedim - 1, 1) |
| 332 | + elif integral_type in horiz_facet_types: |
| 333 | + integration_dim = (basedim, 0) |
| 334 | + else: |
| 335 | + raise NotImplementedError("integral type %s not supported" % integral_type) |
| 336 | + |
| 337 | + if integral_type == 'exterior_facet_bottom': |
| 338 | + entity_ids = [0] |
| 339 | + elif integral_type == 'exterior_facet_top': |
| 340 | + entity_ids = [1] |
| 341 | + else: |
| 342 | + entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim]))) |
| 343 | + |
| 344 | + return integration_dim, entity_ids |
| 345 | + |
| 346 | + |
| 347 | +def pick_mode(mode): |
| 348 | + "Return one of the specialized optimisation modules from a mode string." |
| 349 | + try: |
| 350 | + from firedrake_citations import Citations |
| 351 | + cites = {"vanilla": ("Homolya2017", ), |
| 352 | + "coffee": ("Luporini2016", "Homolya2017", ), |
| 353 | + "spectral": ("Luporini2016", "Homolya2017", "Homolya2017a"), |
| 354 | + "tensor": ("Kirby2006", "Homolya2017", )} |
| 355 | + for c in cites[mode]: |
| 356 | + Citations().register(c) |
| 357 | + except ImportError: |
| 358 | + pass |
| 359 | + if mode == "vanilla": |
| 360 | + import tsfc.vanilla as m |
| 361 | + elif mode == "coffee": |
| 362 | + import tsfc.coffee_mode as m |
| 363 | + elif mode == "spectral": |
| 364 | + import tsfc.spectral as m |
| 365 | + elif mode == "tensor": |
| 366 | + import tsfc.tensor as m |
| 367 | + else: |
| 368 | + raise ValueError("Unknown mode: {}".format(mode)) |
| 369 | + return m |
0 commit comments