Skip to content

Commit 3e15b16

Browse files
committed
use form data index as key instead of form data itself
1 parent 6a2c5b5 commit 3e15b16

File tree

1 file changed

+54
-69
lines changed

1 file changed

+54
-69
lines changed

tsfc/driver.py

Lines changed: 54 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,16 @@ class TSFCFormData(object):
3434
r"""Mimic `ufl.FormData`.
3535
3636
:arg form_data_tuple: A tuple of `ufl.FormData`s.
37+
:arg extraarg_tuple: A tuple of extra `ufl.Argument`s
38+
corresponding to form_data_tuple. These extra
39+
arguments are eventually replaced by the user with
40+
the associated functions in function_tuple after
41+
compiling UFL but before compiling gem. These
42+
arguments thus do not contribute to the rank of the form.
43+
:arg function_tuple: A tuple of functions corresponding
44+
to extraarg_tuple.
3745
:arg original_form: The form from which forms for
3846
`ufl.Formdata`s were extracted.
39-
:arg form_data_extraarg_map: A map from `ufl.FormData`s to
40-
extra `ufl.Argument`s: the user can apply arbitrary
41-
linear transformations to these `ufl.Argument`s and
42-
replace them with corresponding functions stored in
43-
`form_data_function_map`. This must happen after
44-
compiling UFL but before compiling gem. These
45-
`ufl.Arguments` thus do not contribute to the rank
46-
of the form.
47-
:arg form_data_function_map: A map from `ufl.FormData`s to
48-
functions corresponding to the face `ufl.Argument`s in
49-
`form_data_extraarg_map`.
5047
:diagonal: A flag for diagonal matrix assembly.
5148
5249
This class mimics `ufl.FormData`, but is to contain minimum
@@ -69,8 +66,8 @@ class TSFCFormData(object):
6966
+--- form_N ---- gem_N' ---- gem_N ---+
7067
7168
After preprocessing `ufl.FormData`s here:
69+
* Only essential information about the `ufl.FormData`s is retained.
7270
* TSFC can forget `ufl.FormData.original_form`,
73-
* TSFC can forget `ufl.IntegralData.enabled_coefficients`,
7471
* `KernelBuilder`s only need to deal with raw `ufl.Coefficient`s.
7572
7673
Illustration of the structures.
@@ -91,108 +88,96 @@ class TSFCFormData(object):
9188
|____0___||____1___|_ _|____M___| ||________||________| |________||
9289
|_____________________________________|
9390
"""
94-
def __init__(self, form_data_tuple, original_form, form_data_extraarg_map, form_data_function_map, diagonal):
91+
def __init__(self, form_data_tuple, extraarg_tuple, function_tuple, original_form, diagonal):
9592
arguments = set()
96-
for fd in form_data_tuple:
93+
for fd, extraarg in zip(form_data_tuple, extraarg_tuple):
9794
args = []
9895
for arg in fd.preprocessed_form.arguments():
99-
if arg not in form_data_extraarg_map[fd]:
96+
if arg not in extraarg:
10097
args.append(arg)
10198
arguments.update((tuple(args), ))
10299
if len(arguments) != 1:
103100
raise ValueError("Found inconsistent sets of arguments in `FormData`s.")
104101
self.arguments, = tuple(arguments)
105-
# Gathere all coefficients.
102+
# Gather all coefficients.
106103
# If a form contains extra arguments, those will be replaced by corresponding functions
107104
# after compiling UFL, so these functions must be included here, too.
108105
reduced_coefficients_set = set(c for fd in form_data_tuple for c in fd.reduced_coefficients)
109-
for _, val in form_data_function_map.items():
110-
reduced_coefficients_set.update(val)
106+
reduced_coefficients_set.update(chain(*function_tuple))
111107
reduced_coefficients = sorted(reduced_coefficients_set, key=lambda c: c.count())
112-
if len(form_data_tuple) == 1:
113-
self.reduced_coefficients = form_data_tuple[0].reduced_coefficients
114-
self.original_coefficient_positions = form_data_tuple[0].original_coefficient_positions
115-
self.function_replace_map = form_data_tuple[0].function_replace_map
116-
else:
117-
# Reconstruct `ufl.Coefficinet`s with count starting at 0.
118-
function_replace_map = {}
119-
for i, func in enumerate(reduced_coefficients):
120-
for fd in form_data_tuple:
121-
if func in fd.function_replace_map:
122-
coeff = fd.function_replace_map[func]
123-
new_coeff = Coefficient(coeff.ufl_function_space(), count=i)
124-
function_replace_map[func] = new_coeff
125-
break
126-
else:
127-
ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element())
128-
new_coeff = Coefficient(ufl_function_space, count=i)
108+
# Reconstruct `ufl.Coefficinet`s with count starting at 0.
109+
function_replace_map = {}
110+
for i, func in enumerate(reduced_coefficients):
111+
for fd in form_data_tuple:
112+
if func in fd.function_replace_map:
113+
coeff = fd.function_replace_map[func]
114+
new_coeff = Coefficient(coeff.ufl_function_space(), count=i)
129115
function_replace_map[func] = new_coeff
130-
self.reduced_coefficients = reduced_coefficients
131-
self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients())
132-
if f in self.reduced_coefficients]
133-
self.function_replace_map = function_replace_map
116+
break
117+
else:
118+
ufl_function_space = FunctionSpace(func.ufl_domain(), func.ufl_element())
119+
new_coeff = Coefficient(ufl_function_space, count=i)
120+
function_replace_map[func] = new_coeff
121+
self.reduced_coefficients = reduced_coefficients
122+
self.original_coefficient_positions = [i for i, f in enumerate(original_form.coefficients())
123+
if f in self.reduced_coefficients]
124+
self.function_replace_map = function_replace_map
134125

135126
# Translate `ufl.IntegralData`s -> `TSFCIntegralData`.
136-
intg_data_dict = {}
137-
form_data_dict = {}
138-
for form_data in form_data_tuple:
127+
intg_data_info_dict = {}
128+
for form_data_index, form_data in enumerate(form_data_tuple):
139129
for intg_data in form_data.integral_data:
140130
domain = intg_data.domain
141131
integral_type = intg_data.integral_type
142132
subdomain_id = intg_data.subdomain_id
143133
key = (domain, integral_type, subdomain_id)
144-
# Add intg_data.
145-
intg_data_dict.setdefault(key, []).append(intg_data)
146-
# Remember which form_data this intg_data came from.
147-
form_data_dict.setdefault(key, []).append(form_data)
134+
# Add (intg_data, form_data, form_data_index).
135+
intg_data_info_dict.setdefault(key, []).append((intg_data, form_data, form_data_index))
148136
integral_data_list = []
149-
for key in intg_data_dict:
150-
intg_data_list = intg_data_dict[key]
151-
form_data_list = form_data_dict[key]
137+
for key, intg_data_info in intg_data_info_dict.items():
152138
domain, _, _ = key
153139
domain_number = original_form.domain_numbering()[domain]
154-
integral_data_list.append(TSFCIntegralData(key, intg_data_list, form_data_list,
155-
self, domain_number, form_data_function_map))
140+
integral_data_list.append(TSFCIntegralData(key, intg_data_info,
141+
self, domain_number, function_tuple))
156142
self.integral_data = tuple(integral_data_list)
157143

158144

159145
class TSFCIntegralData(object):
160146
r"""Mimics `ufl.IntegralData`.
161147
162148
:arg integral_data_key: (domain, integral_type, subdomain_id) tuple.
163-
:arg integral_data_list: A list of `ufl.IntegralData`.
164-
:arg form_data_list: A list of `ufl.FormData`.
165-
:arg tsfc_form_data: The `TSFCFormData` that will contain this
149+
:arg integral_data_info: A tuple of the lists of integral_data,
150+
form_data, and form_data_index.
151+
:arg tsfc_form_data: The `TSFCFormData` that is to contain this
166152
`TSFCIntegralData` object.
167153
:arg domain_number: The domain number associated with `domain`.
168-
:arg form_data_function_map: A map from `ufl.FormData`s to functions.
154+
:arg function_tuple: A tuple of functions.
169155
170-
This class mimics `ufl.FormData`, but:
171-
* extracts information required by TSFC.
172-
* preprocesses integrals so that `KernelBuilder`s only
173-
need to deal with raw `ufl.Coefficient`s.
156+
After preprocessing `ufl.IntegralData`s here:
157+
* Only essential information about the `ufl.IntegralData`s is retained.
158+
* TSFC can forget `ufl.IntegralData.enabled_coefficients`,
174159
"""
175-
def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_form_data, domain_number, form_data_function_map):
160+
def __init__(self, integral_data_key, intg_data_info, tsfc_form_data, domain_number, function_tuple):
176161
self.domain, self.integral_type, self.subdomain_id = integral_data_key
177162
self.domain_number = domain_number
178163
# Gather/preprocess integrals.
179164
integrals = []
180-
_integral_to_form_data_map = {}
165+
_integral_index_to_form_data_index = []
181166
functions = set()
182-
for intg_data, form_data in zip(integral_data_list, form_data_list):
167+
for intg_data, form_data, form_data_index in intg_data_info:
183168
for integral in intg_data.integrals:
184169
integrand = integral.integrand()
185170
# Replace functions with Coefficients here.
186171
integrand = ufl.replace(integrand, tsfc_form_data.function_replace_map)
187172
new_integral = integral.reconstruct(integrand=integrand)
188173
integrals.append(new_integral)
189174
# Remember which form_data this integral is associated with.
190-
_integral_to_form_data_map[new_integral] = form_data
175+
_integral_index_to_form_data_index.append(form_data_index)
191176
# Gather functions that are enabled in this `TSFCIntegralData`.
192177
functions.update(f for f, enabled in zip(form_data.reduced_coefficients, intg_data.enabled_coefficients) if enabled)
193-
functions.update(form_data_function_map[form_data])
178+
functions.update(function_tuple[form_data_index])
194179
self.integrals = tuple(integrals)
195-
self._integral_to_form_data_map = _integral_to_form_data_map
180+
self._integral_index_to_form_data_index = _integral_index_to_form_data_index
196181
self.arguments = tsfc_form_data.arguments
197182
# This is which coefficient in the original form the
198183
# current coefficient is.
@@ -206,9 +191,9 @@ def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_f
206191
self.coefficients = tuple(tsfc_form_data.function_replace_map[f] for f in functions)
207192
self.coefficient_numbers = tuple(tsfc_form_data.original_coefficient_positions[tsfc_form_data.reduced_coefficients.index(f)] for f in functions)
208193

209-
def integral_to_form_data(self, integral):
210-
r"""Return `ufl.FormData` which `integral` is associated with."""
211-
return self._integral_to_form_data_map[integral]
194+
def integral_index_to_form_data_index(self, integral_index):
195+
r"""Return the form data index given an integral index."""
196+
return self._integral_index_to_form_data_index[integral_index]
212197

213198

214199
def compile_form(form, prefix="form", parameters=None, interface=None, coffee=True, diagonal=False):
@@ -233,7 +218,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr
233218
form_data = ufl_utils.compute_form_data(form, complex_mode=complex_mode)
234219
if interface:
235220
interface = partial(interface, function_replace_map=form_data.function_replace_map)
236-
tsfc_form_data = TSFCFormData((form_data, ), form_data.original_form, {form_data: ()}, {form_data: ()}, diagonal)
221+
tsfc_form_data = TSFCFormData((form_data, ), ((), ), ((), ), form_data.original_form, diagonal)
237222

238223
logger.info(GREEN % "compute_form_data finished in %g seconds.", time.time() - cpu_time)
239224

0 commit comments

Comments
 (0)