Caching Mechanism For DoMINO Training #805
Modulus Pull Request
Description
Summary
Introduces a `CachedDoMINODataset` for DoMINO, which decreases the time of a training epoch by over 50x in a common training scenario on H100s. An example of a config set up to use caching is provided in `cached.yaml`.
Annotating a timeline of a `surface` training run of DoMINO on the DrivAerML dataset on an H100, we can see that the overwhelming portion of the time is spent in the dataloader (within which most of the time is neighbor calculation). In this single sample: 18.92 seconds in the `__getitem__` call versus 0.16 s with the GPU active doing training (as the H100 powers through the kernels), so over 100x overhead. For 52 training samples on one H100 on DrivAerML, this makes a training epoch take over 15 minutes (983.24 seconds on an arbitrarily chosen epoch), so a training run of 500 epochs takes over 5 days. This overhead is exacerbated by the fact that the dataloader contains some GPU-based Warp operations, which means we cannot use dataloader workers.
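For reference, a minimal sketch of how such a per-sample breakdown can be measured (the helper and its arguments are illustrative, not part of this PR; the numbers above came from a profiler timeline):

```python
import time

import torch


def profile_sample(dataset, model, optimizer, idx=0, device="cuda"):
    """Illustrative timing of dataloading vs. GPU work for one sample.

    `dataset`, `model`, and `optimizer` are stand-ins for whatever the
    training script uses; this is a back-of-envelope check, not the PR's API.
    """
    t0 = time.perf_counter()
    sample = dataset[idx]  # all preprocessing, including neighbor search
    t_load = time.perf_counter() - t0

    torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    loss = model(sample).sum()  # placeholder forward + backward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.synchronize(device)  # make GPU time visible to the wall clock
    t_step = time.perf_counter() - t0

    print(f"__getitem__: {t_load:.2f}s, GPU step: {t_step:.2f}s, "
          f"overhead: {t_load / t_step:.0f}x")
```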
In order to mitigate this, we introduce `CachedDoMINODataset` and a new `cache_data.py` stage. Essentially, we do all of the preprocessing work formerly done in `DoMINODataPipe`'s `__getitem__`, except for sampling, in `cache_data.py`; then at train time we just read in the cached data and sample it. To keep file sizes relatively small, we store only the neighbor indices and compute the neighbor properties, such as coordinates, at load time.
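The following is a minimal sketch of that two-stage split, assuming a k-nearest-neighbor search over a point cloud; all names here (`cache_sample`, `CachedDataset`, the `.npz` layout, `k`) are illustrative stand-ins rather than the actual API in `cache_data.py` and `CachedDoMINODataset`:

```python
from pathlib import Path

import numpy as np
import torch
from scipy.spatial import cKDTree
from torch.utils.data import Dataset


def cache_sample(points: np.ndarray, out_file: Path, k: int = 8) -> None:
    """Offline stage: run the expensive neighbor search once, store only indices."""
    tree = cKDTree(points)
    _, neighbor_idx = tree.query(points, k=k)  # (N, k) integer indices
    # Storing indices instead of per-neighbor coordinates keeps files small;
    # the coordinates are reconstructed from `points` at load time.
    np.savez_compressed(out_file, points=points, neighbor_idx=neighbor_idx)


class CachedDataset(Dataset):
    """Train-time stage: cheap load plus per-epoch random subsampling."""

    def __init__(self, cache_dir: Path, n_sample: int = 4096):
        self.files = sorted(cache_dir.glob("*.npz"))
        self.n_sample = n_sample

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, i: int) -> dict:
        data = np.load(self.files[i])
        points, neighbor_idx = data["points"], data["neighbor_idx"]
        # Sampling stays at train time so each epoch sees different points.
        sel = np.random.choice(len(points), self.n_sample, replace=False)
        # Reconstruct neighbor coordinates from indices: (n_sample, k, 3).
        neighbor_coords = points[neighbor_idx[sel]]
        return {
            "coords": torch.from_numpy(points[sel]),
            "neighbor_coords": torch.from_numpy(neighbor_coords),
        }
```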
Here is a timeline of the same sample but using `CachedDoMINODataset`: ~2.29 s in dataloading for a single sample, which is over an 8x improvement. This brings the training time for a 500-epoch run down to under a day. Just as importantly, since the GPU-based Warp operations are now in the caching phase, we can set `num_workers` to something other than 0 and unlock some parallelism and data pipelining. Setting it to 12, we get a further improvement: we see 12 samples being handled efficiently in sequence, then a gap as the dataloader workers finish preparing the next samples, then the next 12 samples handled efficiently. The 12 samples plus the wait for refill take ~2.69 s, which is only a bit more than a single sample with `num_workers=0`.
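Concretely, the loader setup might now look like this (again hypothetical, reusing the `CachedDataset` sketch above):

```python
from pathlib import Path

from torch.utils.data import DataLoader

# With the GPU-based Warp preprocessing moved to the caching stage,
# __getitem__ is CPU-only, so forking dataloader workers is safe.
loader = DataLoader(
    CachedDataset(cache_dir=Path("cache/"), n_sample=4096),
    batch_size=1,
    shuffle=True,
    num_workers=12,   # workers prepare the next samples while the GPU trains
    pin_memory=True,
)
```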
This brings a single training epoch (52 samples) down to ~13.7 seconds (since 12 doesn't divide 52, the 4 final remainder samples end up taking more time than the rest). This means an entire training run is brought down to ~2 hours (from over 5 days without caching).
Checklist
Dependencies