Skip to content

Commit b1a3ac2

Browse files
committed
Adaptive Multigrid
1 parent 69e3cb7 commit b1a3ac2

File tree

4 files changed

+873
-4
lines changed

4 files changed

+873
-4
lines changed

firedrake/mg/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from .mesh import * # noqa: F401
2-
from .interface import * # noqa: F401
3-
from .embedded import * # noqa: F401
4-
from .opencascade_mh import * # noqa: F401
1+
from .mesh import * # noqa: F401
2+
from .interface import * # noqa: F401
3+
from .embedded import * # noqa: F401
4+
from .opencascade_mh import * # noqa: F401
5+
from .adaptive_hierarchy import * # noqa: F401
6+
from .adaptive_transfer_manager import * # noqa: F401

firedrake/mg/adaptive_hierarchy.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
"""
2+
This module contains the class for the AdaptiveMeshHierarchy and
3+
related helper functions
4+
"""
5+
6+
from fractions import Fraction
7+
from collections import defaultdict
8+
import numpy as np
9+
10+
from firedrake.cofunction import Cofunction
11+
from firedrake.function import Function
12+
from firedrake.functionspace import FunctionSpace
13+
from firedrake.mesh import Mesh, Submesh, RelabeledMesh
14+
from firedrake.mg import HierarchyBase
15+
from firedrake.mg.utils import set_level, get_level
16+
from ufl import conditional, gt
17+
18+
__all__ = ["AdaptiveMeshHierarchy"]
19+
20+
21+
class AdaptiveMeshHierarchy(HierarchyBase):
22+
"""
23+
HierarchyBase for hierarchies of adaptively refined meshes
24+
"""
25+
26+
def __init__(self, mesh, refinements_per_level=1, nested=True):
27+
self.meshes = tuple(mesh)
28+
self._meshes = tuple(mesh)
29+
self.submesh_hierarchies = []
30+
self.coarse_to_fine_cells = {}
31+
self.fine_to_coarse_cells = {}
32+
self.fine_to_coarse_cells[Fraction(0, 1)] = None
33+
self.refinements_per_level = refinements_per_level
34+
self.nested = nested
35+
set_level(mesh[0], self, 0)
36+
self.split_cache = {}
37+
38+
def add_mesh(self, mesh):
39+
"""
40+
Adds newly refined mesh into hierarchy.
41+
Then computes the coarse_to_fine and fine_to_coarse mappings.
42+
Constructs intermediate submesh hierarchies with this.
43+
"""
44+
if mesh.topological_dimension() <= 2:
45+
max_children = 4
46+
else:
47+
max_children = 16
48+
self._meshes += tuple(mesh)
49+
self.meshes += tuple(mesh)
50+
coarse_mesh = self.meshes[-2]
51+
level = len(self.meshes)
52+
set_level(self.meshes[-1], self, level - 1)
53+
self._shared_data_cache = defaultdict(dict)
54+
55+
# extract parent child relationships from netgen meshes
56+
(c2f, f2c) = get_c2f_f2c_fd(mesh, coarse_mesh)
57+
c2f_global_key = Fraction(len(self.meshes) - 2, 1)
58+
f2c_global_key = Fraction(len(self.meshes) - 1, 1)
59+
self.coarse_to_fine_cells[c2f_global_key] = c2f
60+
self.fine_to_coarse_cells[f2c_global_key] = np.array(f2c)
61+
62+
# split both the fine and coarse meshes into the submeshes
63+
(coarse_splits, fine_splits, num_children) = split_to_submesh(
64+
mesh, coarse_mesh, c2f, f2c
65+
)
66+
for i in range(1, max_children + 1):
67+
coarse_mesh.mark_entities(coarse_splits[i], i)
68+
mesh.mark_entities(fine_splits[i], int(f"10{i}"))
69+
70+
coarse_indicators = [
71+
coarse_splits[i]
72+
for i in range(1, max_children + 1)
73+
]
74+
coarse_labels = list(range(1, max_children + 1))
75+
coarse_mesh = RelabeledMesh(
76+
coarse_mesh,
77+
coarse_indicators,
78+
coarse_labels,
79+
name="Relabeled_coarse",
80+
)
81+
c_subm = {
82+
j: Submesh(coarse_mesh, coarse_mesh.topology_dm.getDimension(), j)
83+
for j in range(1, max_children + 1)
84+
if any(num_children == j)
85+
}
86+
set_level(coarse_mesh, self, level - 2)
87+
88+
fine_indicators = [
89+
fine_splits[i]
90+
for i in range(1, max_children + 1)
91+
]
92+
fine_labels = list(range(1, max_children + 1))
93+
mesh = RelabeledMesh(
94+
mesh,
95+
fine_indicators,
96+
fine_labels,
97+
)
98+
f_subm = {
99+
int(str(j)[-2:]): Submesh(mesh, mesh.topology_dm.getDimension(), j)
100+
for j in [int("10" + str(i)) for i in range(1, max_children + 1)]
101+
if any(num_children == int(str(j)[-2:]))
102+
}
103+
set_level(mesh, self, level - 1)
104+
105+
# update c2f and f2c for submeshes by mapping numberings
106+
# on full mesh to numberings on coarse mesh
107+
parents_per_child_count = [
108+
len([el for el in c2f if len(el) == j])
109+
for j in range(1, max_children + 1)
110+
] # stores number of parents for each amount of children
111+
c2f_adjusted = {
112+
j: np.zeros((num_parents, j))
113+
for j, num_parents in enumerate(parents_per_child_count, 1)
114+
if num_parents != 0
115+
}
116+
f2c_adjusted = {
117+
j: np.zeros((num_parents * j, 1))
118+
for j, num_parents in enumerate(parents_per_child_count, 1)
119+
if num_parents != 0
120+
}
121+
122+
coarse_full_to_sub_map = {
123+
i: full_to_sub(coarse_mesh, c_subm[i])
124+
for i in c_subm
125+
}
126+
fine_full_to_sub_map = {
127+
j: full_to_sub(mesh, f_subm[j])
128+
for j in f_subm
129+
}
130+
131+
for i, children in enumerate(c2f):
132+
n = len(children)
133+
if 1 <= n <= max_children:
134+
coarse_id_sub = coarse_full_to_sub_map[n][i]
135+
fine_id_sub = fine_full_to_sub_map[n][np.array(children)]
136+
c2f_adjusted[n][coarse_id_sub] = fine_id_sub
137+
138+
for j, parent in enumerate(f2c):
139+
n = num_children[parent].item()
140+
if 1 <= n <= max_children:
141+
fine_id_sub = fine_full_to_sub_map[n][j]
142+
coarse_id_sub = coarse_full_to_sub_map[n][parent.item()]
143+
f2c_adjusted[n][fine_id_sub, 0] = coarse_id_sub
144+
145+
c2f_subm = {
146+
i: {Fraction(0, 1): c2f_adjusted[i].astype(int)}
147+
for i in c2f_adjusted
148+
}
149+
f2c_subm = {i: {Fraction(1, 1): f2c_adjusted[i]} for i in f2c_adjusted}
150+
151+
hierarchy_dict = {
152+
i: HierarchyBase(
153+
[c_subm[i], f_subm[i]], c2f_subm[i], f2c_subm[i], nested=True
154+
)
155+
for i in c_subm
156+
}
157+
self.submesh_hierarchies.append(hierarchy_dict)
158+
159+
def refine(self, refinements):
160+
"""
161+
Refines and adds mesh if input a boolean vector corresponding to cells
162+
"""
163+
ngmesh = self.meshes[-1].netgen_mesh
164+
for i, el in enumerate(ngmesh.Elements2D()):
165+
el.refine = refinements[i]
166+
167+
ngmesh.Refine(adaptive=True)
168+
mesh = Mesh(ngmesh)
169+
self.add_mesh(mesh)
170+
171+
def adapt(self, eta, theta):
172+
"""
173+
Implement Dorfler marking, refines mesh from error estimator
174+
"""
175+
mesh = self.meshes[-1]
176+
W = FunctionSpace(mesh, "DG", 0)
177+
markers = Function(W)
178+
179+
with eta.dat.vec_ro as eta_:
180+
eta_max = eta_.max()[1]
181+
182+
should_refine = conditional(gt(eta, theta * eta_max), 1, 0)
183+
markers.interpolate(should_refine)
184+
185+
refined_mesh = mesh.refine_marked_elements(markers)
186+
self.add_mesh(refined_mesh)
187+
return refined_mesh
188+
189+
def split_function(self, u, child=True):
190+
"""
191+
Split input function across submeshes
192+
"""
193+
V = u.function_space()
194+
full_mesh = V.mesh()
195+
_, level = get_level(full_mesh)
196+
197+
ind = 1 if child else 0
198+
hierarchy_dict = self.submesh_hierarchies[int(level) - ind]
199+
parent_mesh = hierarchy_dict[[*hierarchy_dict][0]].meshes[ind].submesh_parent
200+
parent_space = V.reconstruct(parent_mesh)
201+
u_corr_space = Function(parent_space, val=u.dat)
202+
key = (u, child)
203+
try:
204+
split_functions = self.split_cache[key]
205+
except KeyError:
206+
split_functions = self.split_cache.setdefault(key, {})
207+
208+
for i in hierarchy_dict:
209+
try:
210+
f = split_functions[i].zero()
211+
except KeyError:
212+
V_split = V.reconstruct(mesh=hierarchy_dict[i].meshes[ind])
213+
assert (
214+
V_split.mesh().submesh_parent
215+
== u_corr_space.function_space().mesh()
216+
)
217+
f = split_functions.setdefault(
218+
i,
219+
Function(V_split, name=str(i))
220+
)
221+
222+
f.assign(u_corr_space)
223+
return split_functions
224+
225+
def use_weight(self, V, child):
226+
"""
227+
Counts DoFs across submeshes, computes partition of unity
228+
"""
229+
w = Function(V).assign(1)
230+
splits = self.split_function(w, child)
231+
232+
self.recombine(splits, w, child)
233+
with w.dat.vec as wvec:
234+
wvec.reciprocal()
235+
return w
236+
237+
def recombine(self, split_funcs, f, child=True):
238+
"""
239+
Recombines functions on submeshes back full mesh
240+
"""
241+
V = f.function_space()
242+
f.zero()
243+
parent_mesh = (
244+
split_funcs[[*split_funcs][0]].function_space().mesh().submesh_parent
245+
)
246+
V_label = V.reconstruct(mesh=parent_mesh)
247+
if isinstance(f, Function):
248+
f_label = Function(V_label, val=f.dat)
249+
elif isinstance(f, Cofunction):
250+
f_label = Cofunction(V_label, val=f.dat)
251+
252+
for split_label, val in split_funcs.items():
253+
assert val.function_space().mesh().submesh_parent == parent_mesh
254+
if child:
255+
split_label = int("10" + str(split_label))
256+
if isinstance(f_label, Function):
257+
f_label.assign(val, allow_missing_dofs=True)
258+
else:
259+
curr = Function(f_label.function_space()).assign(
260+
val, allow_missing_dofs=True
261+
)
262+
f_label.assign(f_label + curr) # partition of unity for restriction
263+
return f
264+
265+
266+
def get_c2f_f2c_fd(mesh, coarse_mesh):
267+
"""
268+
Construct coarse->fine and fine->coarse relations by mapping netgen elements to firedrake ones
269+
"""
270+
ngmesh = mesh.netgen_mesh
271+
num_parents = coarse_mesh.num_cells()
272+
273+
if mesh.topology_dm.getDimension() == 2:
274+
parents = ngmesh.parentsurfaceelements.NumPy()
275+
elements = ngmesh.Elements2D()
276+
elif mesh.topology_dm.getDimension() == 3:
277+
parents = ngmesh.parentelements.NumPy()
278+
elements = ngmesh.Elements3D()
279+
else:
280+
raise RuntimeError("Adaptivity not implemented in dimension of mesh")
281+
282+
c2f = [[] for _ in range(num_parents)]
283+
f2c = [[] for _ in range(mesh.num_cells())]
284+
285+
if parents.shape[0] == 0:
286+
raise RuntimeError("Added mesh has not refined any cells from previous mesh")
287+
for l, _ in enumerate(elements):
288+
if parents[l][0] == -1 or l < num_parents:
289+
f2c[mesh._cell_numbering.getOffset(l)].append(
290+
coarse_mesh._cell_numbering.getOffset(l)
291+
)
292+
c2f[coarse_mesh._cell_numbering.getOffset(l)].append(
293+
mesh._cell_numbering.getOffset(l)
294+
)
295+
296+
elif parents[l][0] < num_parents:
297+
fine_ind = mesh._cell_numbering.getOffset(l)
298+
coarse_ind = coarse_mesh._cell_numbering.getOffset(parents[l][0])
299+
f2c[fine_ind].append(coarse_ind)
300+
c2f[coarse_ind].append(fine_ind)
301+
302+
else:
303+
a = parents[parents[l][0]][0]
304+
while a >= num_parents:
305+
a = parents[a][0]
306+
307+
f2c[mesh._cell_numbering.getOffset(l)].append(
308+
coarse_mesh._cell_numbering.getOffset(a)
309+
)
310+
c2f[coarse_mesh._cell_numbering.getOffset(a)].append(
311+
mesh._cell_numbering.getOffset(l)
312+
)
313+
314+
return c2f, np.array(f2c).astype(int)
315+
316+
317+
def split_to_submesh(mesh, coarse_mesh, c2f, f2c):
318+
"""
319+
Computes submesh split from full mesh.
320+
Returns splits which are Functions denoting whether elements
321+
belong to the corresponing submesh (bool)
322+
"""
323+
if mesh.topological_dimension() <= 2:
324+
max_children = 4
325+
else:
326+
max_children = 16
327+
V = FunctionSpace(mesh, "DG", 0)
328+
V2 = FunctionSpace(coarse_mesh, "DG", 0)
329+
coarse_splits = {
330+
i: Function(V2, name=f"{i}_elements") for i in range(1, max_children + 1)
331+
}
332+
fine_splits = {
333+
i: Function(V, name=f"{i}_elements") for i in range(1, max_children + 1)
334+
}
335+
num_children = np.zeros((len(c2f)))
336+
337+
for i, children in enumerate(c2f):
338+
n = len(children)
339+
if 1 <= n <= max_children:
340+
coarse_splits[n].dat.data[i] = 1
341+
num_children[i] = n
342+
343+
for i in range(1, max_children + 1):
344+
fine_splits[i].dat.data[num_children[f2c.squeeze()] == i] = 1
345+
346+
return coarse_splits, fine_splits, num_children
347+
348+
349+
def full_to_sub(mesh, submesh):
350+
"""
351+
Returns the submesh element id associated with the full mesh element id
352+
"""
353+
V1 = FunctionSpace(mesh, "DG", 0)
354+
V2 = FunctionSpace(submesh, "DG", 0)
355+
u1 = Function(V1)
356+
u2 = Function(V2)
357+
u2.dat.data[:] = np.arange(len(u2.dat.data))
358+
u1.assign(u2, allow_missing_dofs=True)
359+
360+
return u1.dat.data.astype(int)

0 commit comments

Comments
 (0)