Skip to content

Commit 03f59a9

Browse files
committed
loopy: Use inferred type for output GlobalArg
Only in the case of 0-forms, where we can control the allocated scalar type.
1 parent 54e2978 commit 03f59a9

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

tsfc/kernel_interface/firedrake.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
2727
class Kernel(object):
2828
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
2929
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
30-
"coefficient_numbers", "__weakref__")
30+
"return_dtype", "coefficient_numbers", "__weakref__")
3131
"""A compiled Kernel object.
3232
3333
:kwarg ast: The COFFEE ast for the kernel.
@@ -40,12 +40,14 @@ class Kernel(object):
4040
:kwarg coefficient_numbers: A list of which coefficients from the
4141
form the kernel needs.
4242
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
43+
:kwarg return_dtype: numpy dtype of the return value.
4344
:kwarg tabulations: The runtime tabulations this kernel requires
4445
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
4546
"""
4647
def __init__(self, ast=None, integral_type=None, oriented=False,
4748
subdomain_id=None, domain_number=None, quadrature_rule=None,
4849
coefficient_numbers=(),
50+
return_dtype=None,
4951
needs_cell_sizes=False):
5052
# Defaults
5153
self.ast = ast
@@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
5557
self.subdomain_id = subdomain_id
5658
self.coefficient_numbers = coefficient_numbers
5759
self.needs_cell_sizes = needs_cell_sizes
60+
self.return_dtype = return_dtype
5861
super(Kernel, self).__init__()
5962

6063

tsfc/kernel_interface/firedrake_loopy.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def make_builder(*args, **kwargs):
2727
class Kernel(object):
2828
__slots__ = ("ast", "integral_type", "oriented", "subdomain_id",
2929
"domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule",
30-
"coefficient_numbers", "__weakref__")
30+
"return_dtype", "coefficient_numbers", "__weakref__")
3131
"""A compiled Kernel object.
3232
3333
:kwarg ast: The loopy kernel object.
@@ -40,12 +40,14 @@ class Kernel(object):
4040
:kwarg coefficient_numbers: A list of which coefficients from the
4141
form the kernel needs.
4242
:kwarg quadrature_rule: The finat quadrature rule used to generate this kernel
43+
:kwarg return_dtype: numpy dtype of the return value.
4344
:kwarg tabulations: The runtime tabulations this kernel requires
4445
:kwarg needs_cell_sizes: Does the kernel require cell sizes.
4546
"""
4647
def __init__(self, ast=None, integral_type=None, oriented=False,
4748
subdomain_id=None, domain_number=None, quadrature_rule=None,
4849
coefficient_numbers=(),
50+
return_dtype=None,
4951
needs_cell_sizes=False):
5052
# Defaults
5153
self.ast = ast
@@ -55,6 +57,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False,
5557
self.subdomain_id = subdomain_id
5658
self.coefficient_numbers = coefficient_numbers
5759
self.needs_cell_sizes = needs_cell_sizes
60+
self.return_dtype = return_dtype
5861
super(Kernel, self).__init__()
5962

6063

@@ -164,8 +167,8 @@ def construct_kernel(self, return_arg, impero_c, precision, index_names):
164167
for name_, shape in self.tabulations:
165168
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))
166169

167-
loopy_kernel = generate_loopy(impero_c, args, precision, self.scalar_type,
168-
"expression_kernel", index_names)
170+
loopy_kernel, _ = generate_loopy(impero_c, args, precision, self.scalar_type,
171+
"expression_kernel", index_names, ignore_return_type=True)
169172
return ExpressionKernel(loopy_kernel, self.oriented, self.cell_sizes,
170173
self.coefficients, self.tabulations)
171174

@@ -207,6 +210,7 @@ def set_arguments(self, arguments, multiindices):
207210
:arg multiindices: GEM argument multiindices
208211
:returns: GEM expression representing the return variable
209212
"""
213+
self.rank = len(arguments)
210214
self.local_tensor, expressions = prepare_arguments(
211215
arguments, multiindices, self.scalar_type, interior_facet=self.interior_facet,
212216
diagonal=self.diagonal)
@@ -277,7 +281,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
277281
:returns: :class:`Kernel` object
278282
"""
279283

280-
args = [self.local_tensor, self.coordinates_arg]
284+
ignore_return_type = self.rank > 0
285+
if ignore_return_type:
286+
args = [self.local_tensor, self.coordinates_arg]
287+
else:
288+
args = [self.coordinates_arg]
281289
if self.kernel.oriented:
282290
args.append(self.cell_orientations_loopy_arg)
283291
if self.kernel.needs_cell_sizes:
@@ -292,8 +300,11 @@ def construct_kernel(self, name, impero_c, precision, index_names, quadrature_ru
292300
args.append(lp.GlobalArg(name_, dtype=self.scalar_type, shape=shape))
293301

294302
self.kernel.quadrature_rule = quadrature_rule
295-
self.kernel.ast = generate_loopy(impero_c, args, precision,
296-
self.scalar_type, name, index_names)
303+
ast, dtype = generate_loopy(impero_c, args, precision,
304+
self.scalar_type, name, index_names,
305+
ignore_return_type=ignore_return_type)
306+
self.kernel.ast = ast
307+
self.kernel.return_dtype = dtype
297308
return self.kernel
298309

299310
def construct_empty_kernel(self, name):

tsfc/loopy.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,17 @@ def active_indices(mapping, ctx):
186186
ctx.active_indices.pop(key)
187187

188188

189-
def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[]):
189+
def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel", index_names=[],
190+
ignore_return_type=True):
190191
"""Generates loopy code.
191192
192193
:arg impero_c: ImperoC tuple with Impero AST and other data
193194
:arg args: list of loopy.GlobalArgs
194195
:arg precision: floating-point precision for printing
195-
:arg scalar_type: type of scalars as C typename string
196+
:arg scalar_type: type of scalars as numpy dtype
196197
:arg kernel_name: function name of the kernel
197198
:arg index_names: pre-assigned index names
199+
:arg ignore_return_type: Ignore inferred return type from impero_c?
198200
:returns: loopy kernel
199201
"""
200202
ctx = LoopyContext()
@@ -205,7 +207,12 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
205207
ctx.epsilon = 10.0 ** (-precision)
206208

207209
# Create arguments
208-
data = list(args)
210+
if ignore_return_type:
211+
return_dtype = scalar_type
212+
data = list(args)
213+
else:
214+
A, return_dtype = impero_c.return_variable
215+
data = [lp.GlobalArg(A.name, shape=A.shape, dtype=return_dtype)] + list(args)
209216
for i, (temp, dtype) in enumerate(assign_dtypes(impero_c.temporaries, scalar_type)):
210217
name = "t%d" % i
211218
if isinstance(temp, gem.Constant):
@@ -240,7 +247,7 @@ def generate(impero_c, args, precision, scalar_type, kernel_name="loopy_kernel",
240247
insn_new.append(insn.copy(priority=len(knl.instructions) - i))
241248
knl = knl.copy(instructions=insn_new)
242249

243-
return knl
250+
return knl, return_dtype
244251

245252

246253
@singledispatch

0 commit comments

Comments
 (0)