add model parallel for inference #55

Draft: wants to merge 5 commits into base: develop

Conversation

@cathalobrien (Contributor) commented Nov 26, 2024

Lets you run inference over multiple GPUs.

All credit goes to @mishooax. This is his implementation; I just added it to anemoi-inference.

I compared the output for n320 1024c running on 1 GPU vs 4 GPUs and it seems to match.

With this I was able to run 9 km inference over 4 nodes, with 4 x 40 GB A100s per node.

I would like feedback on how the input tensor is read and how the output tensor is written: currently all ranks read the input and only rank 0 writes the output. Also, at the moment there is a lot of duplicated logging when you run.
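
For context, here is a rough sketch of that read/write pattern; the function name and the write_output callable are placeholders for illustration, not the actual anemoi-inference code:

import torch

def run_parallel_step(model, input_tensor_torch, model_comm_group, global_rank, write_output):
    # every rank holds the full input tensor and takes part in the sharded forward pass
    with torch.no_grad():
        y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group)

    # only rank 0 writes the result; the other ranks discard their copy
    if global_rank == 0:
        write_output(y_pred)
    return y_pred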

Unfortunately you have to use Slurm to launch an inference job on multiple GPUs (as opposed to anemoi-training, which supports launching interactive jobs with multiple GPUs, e.g. anemoi-training train hardware.num_gpus_per_node=<num_gpus>). I tried launching with torchrun but it didn't work; happy to look into this more, though.

If you are running over multiple nodes, you need to add these lines to your Slurm batch script:

# take the first node in the allocation as the master and resolve its IP address
MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1)
export MASTER_ADDR=$(nslookup $MASTER_ADDR | grep -oP '(?<=Address: ).*')

If you're running on a single node, localhost is used as the address.

@mishooax (Member) commented Dec 5, 2024

Thanks @cathalobrien, this looks good; just a minor comment from my side.
Sadly I don't have the time to test this; if someone else wants to, please go ahead.

@@ -443,55 +474,57 @@ def get_most_recent_datetime(input_fields):

             # Predict next state of atmosphere
             with torch.autocast(device_type=device, dtype=autocast):
-                y_pred = model.predict_step(input_tensor_torch)
+                # y_pred = model.predict_step(input_tensor_torch, model_comm_group)
+                y_pred = model.forward(input_tensor_torch.unsqueeze(2), model_comm_group)

Member: Why is an unsqueeze op needed here?

Contributor (Author): It expects an ensemble dimension I believe?

Member: Ah, yes, somehow I missed that you are now calling forward instead of predict_step.
Is predict_step now obsolete? (If so, should it be removed?)
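
To make the shape change concrete, a small standalone sketch; the axis layout (batch, multistep, grid, variables) and the sizes below are assumptions for illustration only:

import torch

# assumed input layout: (batch, multistep, grid, variables) -- sizes are made up
x = torch.randn(1, 2, 40320, 98)

# forward expects an explicit ensemble axis, so a singleton ensemble
# dimension is inserted at position 2
x_ens = x.unsqueeze(2)
print(x_ens.shape)  # torch.Size([1, 2, 1, 40320, 98])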

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{addr}:{port}",
        timeout=datetime.timedelta(minutes=1),

Member: We should probably set a longer timeout, O(5-10 mins)?

Contributor (Author): Yeah, seems sensible.
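
A minimal sketch of that change, assuming the rank and world size come from the Slurm environment as elsewhere in this PR:

import datetime
import os

import torch.distributed as dist

addr = os.getenv("MASTER_ADDR", "localhost")
port = os.getenv("MASTER_PORT", "10000")

# a 10 minute timeout gives slow-starting ranks on a busy allocation time to join
dist.init_process_group(
    backend="nccl",
    init_method=f"tcp://{addr}:{port}",
    timeout=datetime.timedelta(minutes=10),
    world_size=int(os.getenv("SLURM_NTASKS", "1")),
    rank=int(os.getenv("SLURM_PROCID", "0")),
)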

    if global_rank == 0:
        LOGGER.info("World size: %d", world_size)
    addr = os.getenv("MASTER_ADDR", "localhost")  # localhost should be sufficient to run on a single node
    port = os.getenv("MASTER_PORT", 10000 + random.randint(0, 10000))  # random port between 10,000 and 20,000

Member: Does the fallback to 10000 + random.randint(0, 10000) work without consistent seeding across ranks? Maybe we should use SLURM_JOBID instead?
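
For reference, a sketch of the SLURM_JOBID-based fallback: every rank derives the same port deterministically, so no shared seeding is needed, and the modulus keeps it in the same 10,000-20,000 range:

import os

# derive the fallback port from the job id so all ranks agree on it without any seeding
slurm_jobid = int(os.getenv("SLURM_JOBID", "0"))
port = int(os.getenv("MASTER_PORT", str(10000 + slurm_jobid % 10000)))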
