@@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
2727class Kernel (object ):
2828 __slots__ = ("ast" , "integral_type" , "oriented" , "subdomain_id" ,
2929 "domain_number" , "needs_cell_sizes" , "tabulations" , "quadrature_rule" ,
30- "coefficient_numbers" , "__weakref__" )
30+ "return_dtype" , " coefficient_numbers" , "__weakref__" )
3131 """A compiled Kernel object.
3232
3333 :kwarg ast: The loopy kernel object.
@@ -40,12 +40,14 @@ class Kernel(object):
4040 :kwarg coefficient_numbers: A list of which coefficients from the
4141 form the kernel needs.
4242 :kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
43+ :kwarg return_dtype: numpy dtype of the return value.
4344 :kwarg tabulations: The runtime tabulations this kernel requires
4445 :kwarg needs_cell_sizes: Does the kernel require cell sizes.
4546 """
4647 def __init__ (self , ast = None , integral_type = None , oriented = False ,
4748 subdomain_id = None , domain_number = None , quadrature_rule = None ,
4849 coefficient_numbers = (),
50+ return_dtype = None ,
4951 needs_cell_sizes = False ):
5052 # Defaults
5153 self .ast = ast
@@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
5557 self .subdomain_id = subdomain_id
5658 self .coefficient_numbers = coefficient_numbers
5759 self .needs_cell_sizes = needs_cell_sizes
60+ self .return_dtype = return_dtype
5861 super (Kernel , self ).__init__ ()
5962
6063
@@ -164,8 +167,8 @@ def construct_kernel(self, return_arg, impero_c, precision, index_names):
164167 for name_ , shape in self .tabulations :
165168 args .append (lp .GlobalArg (name_ , dtype = self .scalar_type , shape = shape ))
166169
167- loopy_kernel = generate_loopy (impero_c , args , precision , self .scalar_type ,
168- "expression_kernel" , index_names )
170+ loopy_kernel , _ = generate_loopy (impero_c , args , precision , self .scalar_type ,
171+ "expression_kernel" , index_names , ignore_return_type = True )
169172 return ExpressionKernel (loopy_kernel , self .oriented , self .cell_sizes ,
170173 self .coefficients , self .tabulations )
171174
@@ -207,6 +210,7 @@ def set_arguments(self, arguments, multiindices):
207210 :arg multiindices: GEM argument multiindices
208211 :returns: GEM expression representing the return variable
209212 """
213+ self .rank = len (arguments )
210214 self .local_tensor , expressions = prepare_arguments (
211215 arguments , multiindices , self .scalar_type , interior_facet = self .interior_facet ,
212216 diagonal = self .diagonal )
@@ -277,7 +281,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
277281 :returns: :class:`Kernel` object
278282 """
279283
280- args = [self .local_tensor , self .coordinates_arg ]
284+ ignore_return_type = self .rank > 0
285+ if ignore_return_type :
286+ args = [self .local_tensor , self .coordinates_arg ]
287+ else :
288+ args = [self .coordinates_arg ]
281289 if self .kernel .oriented :
282290 args .append (self .cell_orientations_loopy_arg )
283291 if self .kernel .needs_cell_sizes :
@@ -292,8 +300,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
292300 args .append (lp .GlobalArg (name_ , dtype = self .scalar_type , shape = shape ))
293301
294302 self .kernel .quadrature_rule = quadrature_rule
295- self .kernel .ast = generate_loopy (impero_c , args , precision ,
296- self .scalar_type , name , index_names )
303+ ast , dtype = generate_loopy (impero_c , args , precision ,
304+ self .scalar_type , name , index_names ,
305+ ignore_return_type = ignore_return_type )
306+ self .kernel .ast = ast
307+ self .kernel .return_dtype = dtype
297308 return self .kernel
298309
299310 def construct_empty_kernel (self , name ):
0 commit comments