55import itertools
66from functools import singledispatch
77
8+ import gem
89import numpy
9-
1010import ufl
11- from ufl .corealg .map_dag import map_expr_dag , map_expr_dags
12- from ufl .corealg .multifunction import MultiFunction
13- from ufl .classes import (
14- Argument , CellCoordinate , CellEdgeVectors , CellFacetJacobian ,
15- CellOrientation , CellOrigin , CellVertices , CellVolume , Coefficient ,
16- FacetArea , FacetCoordinate , GeometricQuantity , Jacobian ,
17- JacobianDeterminant , NegativeRestricted , QuadratureWeight ,
18- PositiveRestricted , ReferenceCellVolume , ReferenceCellEdgeVectors ,
19- ReferenceFacetVolume , ReferenceNormal , SpatialCoordinate
20- )
21- from ufl .domain import extract_unique_domain
22-
23- from FIAT .reference_element import make_affine_mapping
24- from FIAT .reference_element import UFCSimplex
25-
26- import gem
11+ from FIAT .reference_element import UFCSimplex , make_affine_mapping
12+ from finat .physically_mapped import (NeedsCoordinateMappingElement ,
13+ PhysicalGeometry )
14+ from finat .point_set import PointSet , PointSingleton
15+ from finat .quadrature import make_quadrature
2716from gem .node import traversal
28- from gem .optimise import ffc_rounding , constant_fold_zero
17+ from gem .optimise import constant_fold_zero , ffc_rounding
2918from gem .unconcatenate import unconcatenate
3019from gem .utils import cached_property
31-
32- from finat .physically_mapped import PhysicalGeometry , NeedsCoordinateMappingElement
33- from finat .point_set import PointSet , PointSingleton
34- from finat .quadrature import make_quadrature
20+ from ufl .classes import (Argument , CellCoordinate , CellEdgeVectors ,
21+ CellFacetJacobian , CellOrientation , CellOrigin ,
22+ CellVertices , CellVolume , Coefficient , FacetArea ,
23+ FacetCoordinate , GeometricQuantity , Jacobian ,
24+ JacobianDeterminant , NegativeRestricted ,
25+ PositiveRestricted , QuadratureWeight ,
26+ ReferenceCellEdgeVectors , ReferenceCellVolume ,
27+ ReferenceFacetVolume , ReferenceNormal ,
28+ SpatialCoordinate )
29+ from ufl .corealg .map_dag import map_expr_dag , map_expr_dags
30+ from ufl .corealg .multifunction import MultiFunction
31+ from ufl .domain import extract_unique_domain
3532
3633from tsfc import ufl2gem
3734from tsfc .finatinterface import as_fiat_cell , create_element
4037 construct_modified_terminal )
4138from tsfc .parameters import is_complex
4239from tsfc .ufl_utils import (ModifiedTerminalMixin , PickRestriction ,
43- entity_avg , one_times , simplify_abs ,
44- preprocess_expression , TSFCConstantMixin )
40+ TSFCConstantMixin , entity_avg , one_times ,
41+ preprocess_expression , simplify_abs )
4542
4643
4744class ContextBase (ProxyKernelInterface ):
@@ -175,31 +172,35 @@ def detJ_at(self, point):
175172 return map_expr_dag (context .translator , expr )
176173
177174 def reference_normals (self ):
178- if not (isinstance (self .interface .fiat_cell , UFCSimplex ) and
179- self .interface .fiat_cell .get_spatial_dimension () == 2 ):
180- raise NotImplementedError ("Only works for triangles for now" )
181- return gem .Literal (numpy .asarray ([self .interface .fiat_cell .compute_normal (i ) for i in range (3 )]))
175+ cell = self .interface .fiat_cell
176+ sd = cell .get_spatial_dimension ()
177+ num_faces = len (cell .get_topology ()[sd - 1 ])
178+
179+ return gem .Literal (numpy .asarray ([cell .compute_normal (i ) for i in range (num_faces )]))
182180
183181 def reference_edge_tangents (self ):
184- return gem .Literal (numpy .asarray ([self .interface .fiat_cell .compute_edge_tangent (i ) for i in range (3 )]))
182+ cell = self .interface .fiat_cell
183+ num_edges = len (cell .get_topology ()[1 ])
184+ return gem .Literal (numpy .asarray ([cell .compute_edge_tangent (i ) for i in range (num_edges )]))
185185
186186 def physical_tangents (self ):
187- if not (isinstance (self .interface .fiat_cell , UFCSimplex ) and
188- self .interface .fiat_cell .get_spatial_dimension () == 2 ):
189- raise NotImplementedError ("Only works for triangles for now" )
190-
191- rts = [self .interface .fiat_cell .compute_tangents (1 , f )[0 ] for f in range (3 )]
192- jac = self .jacobian_at ([1 / 3 , 1 / 3 ])
193-
187+ cell = self .interface .fiat_cell
188+ sd = cell .get_spatial_dimension ()
189+ num_edges = len (cell .get_topology ()[1 ])
194190 els = self .physical_edge_lengths ()
191+ rts = gem .ListTensor ([cell .compute_tangents (1 , i )[0 ] / els [i ] for i in range (num_edges )])
192+ jac = self .jacobian_at (cell .make_points (sd , 0 , sd + 1 )[0 ])
195193
196- return gem .ListTensor ([[(jac [0 , 0 ]* rts [i ][0 ] + jac [0 , 1 ]* rts [i ][1 ]) / els [i ],
197- (jac [1 , 0 ]* rts [i ][0 ] + jac [1 , 1 ]* rts [i ][1 ]) / els [i ]]
198- for i in range (3 )])
194+ return rts @ jac .T
199195
200196 def physical_normals (self ):
197+ cell = self .interface .fiat_cell
198+ if not (isinstance (cell , UFCSimplex ) and cell .get_dimension () == 2 ):
199+ raise NotImplementedError ("Can't do physical normals on that cell yet" )
200+
201+ num_edges = len (cell .get_topology ()[1 ])
201202 pts = self .physical_tangents ()
202- return gem .ListTensor ([[pts [i , 1 ], - 1 * pts [i , 0 ]] for i in range (3 )])
203+ return gem .ListTensor ([[pts [i , 1 ], - 1 * pts [i , 0 ]] for i in range (num_edges )])
203204
204205 def physical_edge_lengths (self ):
205206 expr = ufl .classes .CellEdgeVectors (extract_unique_domain (self .mt .terminal ))
@@ -208,8 +209,11 @@ def physical_edge_lengths(self):
208209 elif self .mt .restriction == '-' :
209210 expr = NegativeRestricted (expr )
210211
211- expr = ufl .as_vector ([ufl .sqrt (ufl .dot (expr [i , :], expr [i , :])) for i in range (3 )])
212- config = {"point_set" : PointSingleton ([1 / 3 , 1 / 3 ])}
212+ cell = self .interface .fiat_cell
213+ sd = cell .get_spatial_dimension ()
214+ num_edges = len (cell .get_topology ()[1 ])
215+ expr = ufl .as_vector ([ufl .sqrt (ufl .dot (expr [i , :], expr [i , :])) for i in range (num_edges )])
216+ config = {"point_set" : PointSingleton (cell .make_points (sd , 0 , sd + 1 )[0 ])}
213217 config .update (self .config )
214218 context = PointSetContext (** config )
215219 expr = self .preprocess (expr , context )
@@ -443,7 +447,8 @@ def callback(facet_i):
443447
444448@translate .register (ReferenceCellEdgeVectors )
445449def translate_reference_cell_edge_vectors (terminal , mt , ctx ):
446- from FIAT .reference_element import TensorProductCell as fiat_TensorProductCell
450+ from FIAT .reference_element import \
451+ TensorProductCell as fiat_TensorProductCell
447452 fiat_cell = ctx .fiat_cell
448453 if isinstance (fiat_cell , fiat_TensorProductCell ):
449454 raise NotImplementedError ("ReferenceCellEdgeVectors not implemented on TensorProductElements yet" )
0 commit comments