-
Notifications
You must be signed in to change notification settings - Fork 237
mpi: Add basic2 mode #2307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
mpi: Add basic2 mode #2307
Changes from all commits
3b8d682
02e3c7b
264d232
65e1ecf
d01e25b
1be01cb
e507e73
1f12612
4c7bbcc
6950fa0
68631a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -436,33 +436,14 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): | |
|
||
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} | ||
|
||
# Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and | ||
# `ofs` are symbolic objects. This mapper tells what data values should be | ||
# sent (OWNED) or received (HALO) given dimension and side | ||
mapper = {} | ||
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): | ||
if d0 in fixed: | ||
continue | ||
sizes = [] | ||
ofs = [] | ||
for d1 in f.dimensions: | ||
if d1 in fixed: | ||
ofs.append(fixed[d1]) | ||
else: | ||
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side) | ||
ofs.append(meta.offset) | ||
sizes.append(meta.size) | ||
mapper[(d0, side, region)] = (sizes, ofs) | ||
mapper = self._make_basic_mapper(f, fixed) | ||
|
||
body = [] | ||
for d in f.dimensions: | ||
if d in fixed: | ||
continue | ||
|
||
name = ''.join('r' if i is d else 'c' for i in distributor.dimensions) | ||
rpeer = FieldFromPointer(name, nb) | ||
name = ''.join('l' if i is d else 'c' for i in distributor.dimensions) | ||
lpeer = FieldFromPointer(name, nb) | ||
rpeer, lpeer = self._make_peers(d, distributor, nb) | ||
|
||
if (d, LEFT) in hse.halos: | ||
# Sending to left, receiving from right | ||
|
@@ -484,6 +465,37 @@ def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): | |
|
||
return HaloUpdate('haloupdate%s' % key, iet, parameters) | ||
|
||
def _make_basic_mapper(self, f, fixed): | ||
""" | ||
Build a mapper `(dim, side, region) -> (size, ofs)` for `f`. `size` and | ||
`ofs` are symbolic objects. This mapper tells what data values should be | ||
sent (OWNED) or received (HALO) given dimension and side | ||
""" | ||
mapper = {} | ||
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): | ||
if d0 in fixed: | ||
continue | ||
sizes = [] | ||
ofs = [] | ||
for d1 in f.dimensions: | ||
if d1 in fixed: | ||
ofs.append(fixed[d1]) | ||
else: | ||
meta = f._C_get_field(region if d0 is d1 else NOPAD, d1, side) | ||
ofs.append(meta.offset) | ||
sizes.append(meta.size) | ||
mapper[(d0, side, region)] = (sizes, ofs) | ||
|
||
return mapper | ||
|
||
def _make_peers(self, d, distributor, nb): | ||
rname = ''.join('r' if i is d else 'c' for i in distributor.dimensions) | ||
rpeer = FieldFromPointer(rname, nb) | ||
lname = ''.join('l' if i is d else 'c' for i in distributor.dimensions) | ||
lpeer = FieldFromPointer(lname, nb) | ||
|
||
return rpeer, lpeer | ||
|
||
def _call_haloupdate(self, name, f, hse, *args): | ||
comm = f.grid.distributor._obj_comm | ||
nb = f.grid.distributor._obj_neighborhood | ||
|
@@ -527,6 +539,118 @@ def _make_body(self, callcompute, remainder, haloupdates, halowaits): | |
return List(body=body) | ||
|
||
|
||
class Basic2HaloExchangeBuilder(BasicHaloExchangeBuilder): | ||
|
||
""" | ||
A BasicHaloExchangeBuilder using pre-allocated buffers for | ||
message size. | ||
|
||
Generates: | ||
|
||
haloupdate() | ||
compute() | ||
""" | ||
|
||
def _make_msg(self, f, hse, key): | ||
# Pass the fixed mapper e.g. {t: otime} | ||
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} | ||
|
||
return MPIMsgBasic2('msg%d' % key, f, hse.halos, fixed) | ||
|
||
def _make_sendrecv(self, f, hse, key, msg=None): | ||
georgebisbas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cast = cast_mapper[(f.c0.dtype, '*')] | ||
comm = f.grid.distributor._obj_comm | ||
|
||
bufg = FieldFromPointer(msg._C_field_bufg, msg) | ||
bufs = FieldFromPointer(msg._C_field_bufs, msg) | ||
|
||
ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions] | ||
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions] | ||
|
||
fromrank = Symbol(name='fromrank') | ||
torank = Symbol(name='torank') | ||
|
||
sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg) | ||
for i in range(len(f._dist_dimensions))] | ||
|
||
arguments = [cast(bufg)] + sizes + list(f.handles) + ofsg | ||
gather = Gather('gather%s' % key, arguments) | ||
# The `gather` is unnecessary if sending to MPI.PROC_NULL | ||
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather) | ||
|
||
arguments = [cast(bufs)] + sizes + list(f.handles) + ofss | ||
scatter = Scatter('scatter%s' % key, arguments) | ||
# The `scatter` must be guarded as we must not alter the halo values along | ||
# the domain boundary, where the sender is actually MPI.PROC_NULL | ||
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter) | ||
|
||
count = reduce(mul, sizes, 1)*dtype_len(f.dtype) | ||
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg)) | ||
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg)) | ||
recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)), | ||
fromrank, Integer(13), comm, rrecv]) | ||
send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)), | ||
torank, Integer(13), comm, rsend]) | ||
|
||
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')]) | ||
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')]) | ||
|
||
iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter]) | ||
|
||
parameters = (list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) | ||
|
||
return SendRecv('sendrecv%s' % key, iet, parameters, bufg, bufs) | ||
|
||
def _call_sendrecv(self, name, *args, msg=None, haloid=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why do you need a There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similarly to the same method in OverlapHalo. Each call to sendrecv has its own message indexed pointer |
||
# Drop `sizes` as this HaloExchangeBuilder conveys them through `msg` | ||
f, _, ofsg, ofss, fromrank, torank, comm = args | ||
msg = Byref(IndexedPointer(msg, haloid)) | ||
return Call(name, list(f.handles) + ofsg + ofss + [fromrank, torank, comm, msg]) | ||
|
||
def _make_haloupdate(self, f, hse, key, sendrecv, **kwargs): | ||
distributor = f.grid.distributor | ||
nb = distributor._obj_neighborhood | ||
comm = distributor._obj_comm | ||
|
||
fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices} | ||
|
||
mapper = self._make_basic_mapper(f, fixed) | ||
|
||
body = [] | ||
for d in f.dimensions: | ||
if d in fixed: | ||
continue | ||
|
||
rpeer, lpeer = self._make_peers(d, distributor, nb) | ||
|
||
if (d, LEFT) in hse.halos: | ||
# Sending to left, receiving from right | ||
lsizes, lofs = mapper[(d, LEFT, OWNED)] | ||
rsizes, rofs = mapper[(d, RIGHT, HALO)] | ||
args = [f, lsizes, lofs, rofs, rpeer, lpeer, comm] | ||
body.append(self._call_sendrecv(sendrecv.name, *args, haloid=len(body), | ||
**kwargs)) | ||
|
||
if (d, RIGHT) in hse.halos: | ||
# Sending to right, receiving from left | ||
rsizes, rofs = mapper[(d, RIGHT, OWNED)] | ||
lsizes, lofs = mapper[(d, LEFT, HALO)] | ||
args = [f, rsizes, rofs, lofs, lpeer, rpeer, comm] | ||
body.append(self._call_sendrecv(sendrecv.name, *args, haloid=len(body), | ||
**kwargs)) | ||
|
||
iet = List(body=body) | ||
|
||
parameters = list(f.handles) + [comm, nb] + list(fixed.values()) + [kwargs['msg']] | ||
|
||
return HaloUpdate('haloupdate%s' % key, iet, parameters) | ||
|
||
def _call_haloupdate(self, name, f, hse, msg): | ||
call = super()._call_haloupdate(name, f, hse) | ||
call = call._rebuild(arguments=call.arguments + (msg,)) | ||
return call | ||
|
||
|
||
class DiagHaloExchangeBuilder(BasicHaloExchangeBuilder): | ||
|
||
""" | ||
|
@@ -1003,6 +1127,7 @@ def _call_poke(self, poke): | |
|
||
mpi_registry = { | ||
'basic': BasicHaloExchangeBuilder, | ||
'basic2': Basic2HaloExchangeBuilder, | ||
'diag': DiagHaloExchangeBuilder, | ||
'diag2': Diag2HaloExchangeBuilder, | ||
'overlap': OverlapHaloExchangeBuilder, | ||
|
@@ -1112,7 +1237,7 @@ class MPIRequestObject(LocalObject): | |
dtype = type('MPI_Request', (c_void_p,), {}) | ||
|
||
|
||
class MPIMsg(CompositeObject): | ||
class MPIMsgBase(CompositeObject): | ||
|
||
_C_field_bufs = 'bufs' | ||
_C_field_bufg = 'bufg' | ||
|
@@ -1135,17 +1260,6 @@ class MPIMsg(CompositeObject): | |
|
||
__rargs__ = ('name', 'target', 'halos') | ||
|
||
def __init__(self, name, target, halos): | ||
self._target = target | ||
self._halos = halos | ||
|
||
super().__init__(name, 'msg', self.fields) | ||
|
||
# Required for buffer allocation/deallocation before/after jumping/returning | ||
# to/from C-land | ||
self._allocator = None | ||
self._memfree_args = [] | ||
|
||
def __del__(self): | ||
self._C_memfree() | ||
|
||
|
@@ -1184,6 +1298,17 @@ def _as_number(self, v, args): | |
assert args is not None | ||
return int(subs_op_args(v, args)) | ||
|
||
def _allocate_buffers(self, f, shape, entry): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This could use some whitespace to make it more readable There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. agree, added a docstring as well |
||
# Allocate the send/recv buffers | ||
entry.sizes = (c_int*len(shape))(*shape) | ||
size = reduce(mul, shape)*dtype_len(self.target.dtype) | ||
ctype = dtype_to_ctype(f.dtype) | ||
entry.bufg, bufg_memfree_args = self._allocator._alloc_C_libcall(size, ctype) | ||
entry.bufs, bufs_memfree_args = self._allocator._alloc_C_libcall(size, ctype) | ||
# The `memfree_args` will be used to deallocate the buffer upon | ||
# returning from C-land | ||
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args]) | ||
|
||
def _arg_defaults(self, allocator, alias, args=None): | ||
# Lazy initialization if `allocator` is necessary as the `allocator` | ||
# type isn't really known until an Operator is constructed | ||
|
@@ -1201,17 +1326,9 @@ def _arg_defaults(self, allocator, alias, args=None): | |
except AttributeError: | ||
assert side == CENTER | ||
shape.append(self._as_number(f._size_domain[dim], args)) | ||
entry.sizes = (c_int*len(shape))(*shape) | ||
|
||
# Allocate the send/recv buffers | ||
size = reduce(mul, shape)*dtype_len(self.target.dtype) | ||
ctype = dtype_to_ctype(f.dtype) | ||
entry.bufg, bufg_memfree_args = allocator._alloc_C_libcall(size, ctype) | ||
entry.bufs, bufs_memfree_args = allocator._alloc_C_libcall(size, ctype) | ||
|
||
# The `memfree_args` will be used to deallocate the buffer upon | ||
# returning from C-land | ||
self._memfree_args.extend([bufg_memfree_args, bufs_memfree_args]) | ||
self._allocate_buffers(f, shape, entry) | ||
|
||
return {self.name: self.value} | ||
|
||
|
@@ -1232,6 +1349,99 @@ def _arg_apply(self, *args, **kwargs): | |
self._C_memfree() | ||
|
||
|
||
class MPIMsg(MPIMsgBase): | ||
|
||
def __init__(self, name, target, halos): | ||
self._target = target | ||
self._halos = halos | ||
|
||
super().__init__(name, 'msg', self.fields) | ||
|
||
# Required for buffer allocation/deallocation before/after jumping/returning | ||
# to/from C-land | ||
self._allocator = None | ||
self._memfree_args = [] | ||
|
||
def _arg_defaults(self, allocator, alias, args=None): | ||
# Lazy initialization if `allocator` is necessary as the `allocator` | ||
# type isn't really known until an Operator is constructed | ||
self._allocator = allocator | ||
|
||
f = alias or self.target.c0 | ||
for i, halo in enumerate(self.halos): | ||
entry = self.value[i] | ||
|
||
# Buffer shape for this peer | ||
shape = [] | ||
for dim, side in zip(*halo): | ||
try: | ||
shape.append(getattr(f._size_owned[dim], side.name)) | ||
except AttributeError: | ||
assert side is CENTER | ||
shape.append(self._as_number(f._size_domain[dim], args)) | ||
|
||
# Allocate the send/recv buffers | ||
self._allocate_buffers(f, shape, entry) | ||
|
||
return {self.name: self.value} | ||
|
||
|
||
class MPIMsgBasic2(MPIMsgBase): | ||
|
||
def __init__(self, name, target, halos, fixed=None): | ||
self._target = target | ||
self._halos = halos | ||
|
||
super().__init__(name, 'msg', self.fields) | ||
|
||
# Required for buffer allocation/deallocation before/after jumping/returning | ||
# to/from C-land | ||
self._fixed = fixed | ||
self._allocator = None | ||
self._memfree_args = [] | ||
|
||
def _arg_defaults(self, allocator, alias, args=None): | ||
# Lazy initialization if `allocator` is necessary as the `allocator` | ||
# type isn't really known until an Operator is constructed | ||
self._allocator = allocator | ||
|
||
f = alias or self.target.c0 | ||
|
||
fixed = self._fixed | ||
|
||
# Build a mapper `(dim, side, region) -> (size)` for `f`. | ||
mapper = {} | ||
for d0, side, region in product(f.dimensions, (LEFT, RIGHT), (OWNED, HALO)): | ||
if d0 in fixed: | ||
continue | ||
sizes = [] | ||
for d1 in f.dimensions: | ||
if d1 in fixed: | ||
continue | ||
if d0 is d1: | ||
if region is OWNED: | ||
sizes.append(getattr(f._size_owned[d0], side.name)) | ||
elif region is HALO: | ||
sizes.append(getattr(f._size_halo[d0], side.name)) | ||
else: | ||
georgebisbas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sizes.append(self._as_number(f._size_nopad[d1], args)) | ||
mapper[(d0, side, region)] = sizes | ||
|
||
i = 0 | ||
for d in f.dimensions: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why not :
like we have in the other message types? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I followed the basic style, which does not have yet a method to cleanup the redundant halos as in diag2 for example:
I tried this but did not manage to get it working nicely. |
||
if d in fixed: | ||
georgebisbas marked this conversation as resolved.
Show resolved
Hide resolved
|
||
continue | ||
|
||
for side in (LEFT, RIGHT): | ||
if (d, side) in self.halos: | ||
entry = self.value[i] | ||
i += 1 | ||
shape = mapper[(d, side, OWNED)] | ||
self._allocate_buffers(f, shape, entry) | ||
|
||
return {self.name: self.value} | ||
|
||
|
||
class MPIMsgEnriched(MPIMsg): | ||
|
||
_C_field_ofss = 'ofss' | ||
|
Uh oh!
There was an error while loading. Please reload this page.