add model parallel for inference #55
base: develop
Conversation
Thanks @cathalobrien - this looks good, just a minor comment from my side.
@@ -443,55 +474,57 @@ def get_most_recent_datetime(input_fields):
     # Predict next state of atmosphere
     with torch.autocast(device_type=device, dtype=autocast):
-        y_pred = model.predict_step(input_tensor_torch)
+        # y_pred = model.predict_step(input_tensor_torch, model_comm_group)
+        y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group)
why is an unsqueeze op needed here?
It expects an ensemble dimension I believe?
Ah, yes, somehow I missed that you are now calling forward instead of predict_step.
Is predict_step now obsolete? (If so, should it be removed?)
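For context, a minimal sketch of what the unsqueeze(2) does, assuming forward expects an extra ensemble dimension at position 2 while predict_step does not. The tensor layout and sizes below (batch, multistep, grid, variables) are assumptions for illustration only, not taken from the PR.

    import torch

    # Arbitrary example sizes; the (batch, multistep, grid, variables) layout is an assumption.
    batch, multistep, grid, nvars = 1, 2, 1000, 98
    x = torch.randn(batch, multistep, grid, nvars)

    # forward() is assumed to expect an ensemble dimension, so a singleton axis is
    # inserted at position 2 for deterministic (single-member) inference.
    x_ens = x.unsqueeze(2)  # -> (batch, multistep, 1, grid, nvars)
    assert x_ens.shape[2] == 1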
dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{addr}:{port}",
    timeout=datetime.timedelta(minutes=1),
We should probably set a longer timeout, O(5-10 mins)?
Yeah, seems sensible.
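For reference, a sketch of the suggested change, reusing the same torch.distributed call as in the diff above; the 10-minute value is just one point in the O(5-10 min) range mentioned, not a setting this PR has adopted.

    import datetime
    import torch.distributed as dist

    def init_process_group_with_longer_timeout(addr: str, port: int, world_size: int, global_rank: int) -> None:
        # Same initialisation as in the diff above, but with a longer timeout so slow
        # ranks (e.g. while loading a large checkpoint) do not trip the rendezvous.
        # world_size and global_rank are assumed to be available as in the surrounding code.
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{addr}:{port}",
            timeout=datetime.timedelta(minutes=10),
            world_size=world_size,
            rank=global_rank,
        )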
if global_rank == 0:
    LOGGER.info("World size: %d", world_size)

addr = os.getenv("MASTER_ADDR", "localhost")  # localhost should be sufficient to run on a single node
port = os.getenv("MASTER_PORT", 10000 + random.randint(0, 10000))  # random port between 10,000 and 20,000
Does the fallback to 10000 + random.randint(0, 10000) work without consistent seeding across ranks? Maybe we should use SLURM_JOBID instead?
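A sketch of the reviewer's suggestion: derive the fallback port deterministically from SLURM_JOBID, so every rank computes the same value without any shared RNG seeding. Illustrative only, not code from the PR.

    import os

    # Every rank in the job sees the same SLURM_JOBID, so the derived port is
    # identical across ranks without coordinating random seeds.
    job_id = os.getenv("SLURM_JOBID")
    if job_id is not None:
        fallback_port = 10000 + int(job_id) % 10000  # stays in the 10,000-19,999 range
    else:
        fallback_port = 10000  # arbitrary fixed fallback (assumption)
    port = int(os.getenv("MASTER_PORT", str(fallback_port)))  # an explicit MASTER_PORT still wins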
Lets you run inference over multiple GPUs.
All credit goes to @mishooax. This is his implementation; I just added it to anemoi-inference.
I compared output for n320 1024c running on 1 GPU vs 4 GPUs and it seems to match.
With this I was able to run 9 km inference over 4 nodes with 4x 40 GB A100s per node.
I would like feedback on how the input tensor is read and how the output tensor is written. Currently all ranks read the input and only rank 0 writes the output. Also, at the moment there is a lot of duplicated logging when you run.
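To make the current I/O pattern concrete, here is a rough sketch of "only rank 0 writes output", plus a rank-gated logger that would also cut the duplicated log lines. The helper names and the writer object are hypothetical, not the PR's actual API.

    import logging
    import torch.distributed as dist

    LOGGER = logging.getLogger(__name__)

    def is_rank_zero() -> bool:
        # Single-process runs have no initialised process group.
        return not dist.is_initialized() or dist.get_rank() == 0

    def write_output(y_pred, writer) -> None:
        # Hypothetical helper: all ranks hold y_pred, but only rank 0 writes it out.
        if is_rank_zero():
            writer.write(y_pred)

    def log_info(msg: str, *args) -> None:
        # Rank-gated logging avoids emitting one copy of every message per rank.
        if is_rank_zero():
            LOGGER.info(msg, *args)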
Unfortunately you have to use Slurm to launch an inference job on multiple GPUs (as opposed to anemoi-training, which supports launching interactive jobs with multiple GPUs, e.g. anemoi-training train hardware.num_gpus_per_node=<num_gpus>). I tried launching with torchrun but it didn't work; happy to look into this more though.

If you are running over multiple nodes, you need to add these lines to your Slurm batch script.
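The exact batch-script lines are not reproduced here; as a rough sketch, and assuming the MASTER_ADDR/MASTER_PORT rendezvous shown in the code above, a multi-node job would typically export something like the following (adjust to your cluster).

    # Sketch only - not the PR's exact lines. Point every rank at the first node
    # in the allocation and derive a job-specific port.
    export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
    export MASTER_PORT=$((10000 + SLURM_JOBID % 10000))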
If you're running over a single node, localhost is used as the address.