@@ -34,19 +34,16 @@ class TSFCFormData(object):
3434 r"""Mimic `ufl.FormData`.
3535
3636 :arg form_data_tuple: A tuple of `ufl.FormData`s.
37+ :arg extraarg_tuple: A tuple of extra `ufl.Argument`s
38+ corresponding to form_data_tuple. These extra
39+ arguments are eventually replaced by the user with
40+ the associated functions in function_tuple after
41+ compiling UFL but before compiling gem. These
42+ arguments thus do not contribute to the rank of the form.
43+ :arg function_tuple: A tuple of functions corresponding
44+ to extraarg_tuple.
3745 :arg original_form: The form from which forms for
3846 `ufl.Formdata`s were extracted.
39- :arg form_data_extraarg_map: A map from `ufl.FormData`s to
40- extra `ufl.Argument`s: the user can apply arbitrary
41- linear transformations to these `ufl.Argument`s and
42- replace them with corresponding functions stored in
43- `form_data_function_map`. This must happen after
44- compiling UFL but before compiling gem. These
45- `ufl.Arguments` thus do not contribute to the rank
46- of the form.
47- :arg form_data_function_map: A map from `ufl.FormData`s to
48- functions corresponding to the face `ufl.Argument`s in
49- `form_data_extraarg_map`.
5047 :diagonal: A flag for diagonal matrix assembly.
5148
5249 This class mimics `ufl.FormData`, but is to contain minimum
@@ -69,8 +66,8 @@ class TSFCFormData(object):
6966 +--- form_N ---- gem_N' ---- gem_N ---+
7067
7168 After preprocessing `ufl.FormData`s here:
69+ * Only essential information about the `ufl.FormData`s is retained.
7270 * TSFC can forget `ufl.FormData.original_form`,
73- * TSFC can forget `ufl.IntegralData.enabled_coefficients`,
7471 * `KernelBuilder`s only need to deal with raw `ufl.Coefficient`s.
7572
7673 Illustration of the structures.
@@ -91,108 +88,96 @@ class TSFCFormData(object):
9188 |____0___||____1___|_ _|____M___| ||________||________| |________||
9289 |_____________________________________|
9390 """
94- def __init__ (self , form_data_tuple , original_form , form_data_extraarg_map , form_data_function_map , diagonal ):
91+ def __init__ (self , form_data_tuple , extraarg_tuple , function_tuple , original_form , diagonal ):
9592 arguments = set ()
96- for fd in form_data_tuple :
93+ for fd , extraarg in zip ( form_data_tuple , extraarg_tuple ) :
9794 args = []
9895 for arg in fd .preprocessed_form .arguments ():
99- if arg not in form_data_extraarg_map [ fd ] :
96+ if arg not in extraarg :
10097 args .append (arg )
10198 arguments .update ((tuple (args ), ))
10299 if len (arguments ) != 1 :
103100 raise ValueError ("Found inconsistent sets of arguments in `FormData`s." )
104101 self .arguments , = tuple (arguments )
105- # Gathere all coefficients.
102+ # Gather all coefficients.
106103 # If a form contains extra arguments, those will be replaced by corresponding functions
107104 # after compiling UFL, so these functions must be included here, too.
108105 reduced_coefficients_set = set (c for fd in form_data_tuple for c in fd .reduced_coefficients )
109- for _ , val in form_data_function_map .items ():
110- reduced_coefficients_set .update (val )
106+ reduced_coefficients_set .update (chain (* function_tuple ))
111107 reduced_coefficients = sorted (reduced_coefficients_set , key = lambda c : c .count ())
112- if len (form_data_tuple ) == 1 :
113- self .reduced_coefficients = form_data_tuple [0 ].reduced_coefficients
114- self .original_coefficient_positions = form_data_tuple [0 ].original_coefficient_positions
115- self .function_replace_map = form_data_tuple [0 ].function_replace_map
116- else :
117- # Reconstruct `ufl.Coefficinet`s with count starting at 0.
118- function_replace_map = {}
119- for i , func in enumerate (reduced_coefficients ):
120- for fd in form_data_tuple :
121- if func in fd .function_replace_map :
122- coeff = fd .function_replace_map [func ]
123- new_coeff = Coefficient (coeff .ufl_function_space (), count = i )
124- function_replace_map [func ] = new_coeff
125- break
126- else :
127- ufl_function_space = FunctionSpace (func .ufl_domain (), func .ufl_element ())
128- new_coeff = Coefficient (ufl_function_space , count = i )
108+ # Reconstruct `ufl.Coefficinet`s with count starting at 0.
109+ function_replace_map = {}
110+ for i , func in enumerate (reduced_coefficients ):
111+ for fd in form_data_tuple :
112+ if func in fd .function_replace_map :
113+ coeff = fd .function_replace_map [func ]
114+ new_coeff = Coefficient (coeff .ufl_function_space (), count = i )
129115 function_replace_map [func ] = new_coeff
130- self .reduced_coefficients = reduced_coefficients
131- self .original_coefficient_positions = [i for i , f in enumerate (original_form .coefficients ())
132- if f in self .reduced_coefficients ]
133- self .function_replace_map = function_replace_map
116+ break
117+ else :
118+ ufl_function_space = FunctionSpace (func .ufl_domain (), func .ufl_element ())
119+ new_coeff = Coefficient (ufl_function_space , count = i )
120+ function_replace_map [func ] = new_coeff
121+ self .reduced_coefficients = reduced_coefficients
122+ self .original_coefficient_positions = [i for i , f in enumerate (original_form .coefficients ())
123+ if f in self .reduced_coefficients ]
124+ self .function_replace_map = function_replace_map
134125
135126 # Translate `ufl.IntegralData`s -> `TSFCIntegralData`.
136- intg_data_dict = {}
137- form_data_dict = {}
138- for form_data in form_data_tuple :
127+ intg_data_info_dict = {}
128+ for form_data_index , form_data in enumerate (form_data_tuple ):
139129 for intg_data in form_data .integral_data :
140130 domain = intg_data .domain
141131 integral_type = intg_data .integral_type
142132 subdomain_id = intg_data .subdomain_id
143133 key = (domain , integral_type , subdomain_id )
144- # Add intg_data.
145- intg_data_dict .setdefault (key , []).append (intg_data )
146- # Remember which form_data this intg_data came from.
147- form_data_dict .setdefault (key , []).append (form_data )
134+ # Add (intg_data, form_data, form_data_index).
135+ intg_data_info_dict .setdefault (key , []).append ((intg_data , form_data , form_data_index ))
148136 integral_data_list = []
149- for key in intg_data_dict :
150- intg_data_list = intg_data_dict [key ]
151- form_data_list = form_data_dict [key ]
137+ for key , intg_data_info in intg_data_info_dict .items ():
152138 domain , _ , _ = key
153139 domain_number = original_form .domain_numbering ()[domain ]
154- integral_data_list .append (TSFCIntegralData (key , intg_data_list , form_data_list ,
155- self , domain_number , form_data_function_map ))
140+ integral_data_list .append (TSFCIntegralData (key , intg_data_info ,
141+ self , domain_number , function_tuple ))
156142 self .integral_data = tuple (integral_data_list )
157143
158144
159145class TSFCIntegralData (object ):
160146 r"""Mimics `ufl.IntegralData`.
161147
162148 :arg integral_data_key: (domain, integral_type, subdomain_id) tuple.
163- :arg integral_data_list : A list of `ufl.IntegralData`.
164- :arg form_data_list: A list of `ufl.FormData` .
165- :arg tsfc_form_data: The `TSFCFormData` that will contain this
149+ :arg integral_data_info : A tuple of the lists of integral_data,
150+ form_data, and form_data_index .
151+ :arg tsfc_form_data: The `TSFCFormData` that is to contain this
166152 `TSFCIntegralData` object.
167153 :arg domain_number: The domain number associated with `domain`.
168- :arg form_data_function_map : A map from `ufl.FormData`s to functions.
154+ :arg function_tuple : A tuple of functions.
169155
170- This class mimics `ufl.FormData`, but:
171- * extracts information required by TSFC.
172- * preprocesses integrals so that `KernelBuilder`s only
173- need to deal with raw `ufl.Coefficient`s.
156+ After preprocessing `ufl.IntegralData`s here:
157+ * Only essential information about the `ufl.IntegralData`s is retained.
158+ * TSFC can forget `ufl.IntegralData.enabled_coefficients`,
174159 """
175- def __init__ (self , integral_data_key , integral_data_list , form_data_list , tsfc_form_data , domain_number , form_data_function_map ):
160+ def __init__ (self , integral_data_key , intg_data_info , tsfc_form_data , domain_number , function_tuple ):
176161 self .domain , self .integral_type , self .subdomain_id = integral_data_key
177162 self .domain_number = domain_number
178163 # Gather/preprocess integrals.
179164 integrals = []
180- _integral_to_form_data_map = {}
165+ _integral_index_to_form_data_index = []
181166 functions = set ()
182- for intg_data , form_data in zip ( integral_data_list , form_data_list ) :
167+ for intg_data , form_data , form_data_index in intg_data_info :
183168 for integral in intg_data .integrals :
184169 integrand = integral .integrand ()
185170 # Replace functions with Coefficients here.
186171 integrand = ufl .replace (integrand , tsfc_form_data .function_replace_map )
187172 new_integral = integral .reconstruct (integrand = integrand )
188173 integrals .append (new_integral )
189174 # Remember which form_data this integral is associated with.
190- _integral_to_form_data_map [ new_integral ] = form_data
175+ _integral_index_to_form_data_index . append ( form_data_index )
191176 # Gather functions that are enabled in this `TSFCIntegralData`.
192177 functions .update (f for f , enabled in zip (form_data .reduced_coefficients , intg_data .enabled_coefficients ) if enabled )
193- functions .update (form_data_function_map [ form_data ])
178+ functions .update (function_tuple [ form_data_index ])
194179 self .integrals = tuple (integrals )
195- self ._integral_to_form_data_map = _integral_to_form_data_map
180+ self ._integral_index_to_form_data_index = _integral_index_to_form_data_index
196181 self .arguments = tsfc_form_data .arguments
197182 # This is which coefficient in the original form the
198183 # current coefficient is.
@@ -206,9 +191,9 @@ def __init__(self, integral_data_key, integral_data_list, form_data_list, tsfc_f
206191 self .coefficients = tuple (tsfc_form_data .function_replace_map [f ] for f in functions )
207192 self .coefficient_numbers = tuple (tsfc_form_data .original_coefficient_positions [tsfc_form_data .reduced_coefficients .index (f )] for f in functions )
208193
209- def integral_to_form_data (self , integral ):
210- r"""Return `ufl.FormData` which `integral` is associated with ."""
211- return self ._integral_to_form_data_map [ integral ]
194+ def integral_index_to_form_data_index (self , integral_index ):
195+ r"""Return the form data index given an integral index ."""
196+ return self ._integral_index_to_form_data_index [ integral_index ]
212197
213198
214199def compile_form (form , prefix = "form" , parameters = None , interface = None , coffee = True , diagonal = False ):
@@ -233,7 +218,7 @@ def compile_form(form, prefix="form", parameters=None, interface=None, coffee=Tr
233218 form_data = ufl_utils .compute_form_data (form , complex_mode = complex_mode )
234219 if interface :
235220 interface = partial (interface , function_replace_map = form_data .function_replace_map )
236- tsfc_form_data = TSFCFormData ((form_data , ), form_data . original_form , { form_data : ()}, { form_data : ()} , diagonal )
221+ tsfc_form_data = TSFCFormData ((form_data , ), ((), ), ((), ), form_data . original_form , diagonal )
237222
238223 logger .info (GREEN % "compute_form_data finished in %g seconds." , time .time () - cpu_time )
239224
0 commit comments