
[Question]: activation offload won't work for torch version < 2.5 #2456


Irvingwangjr opened this issue Mar 4, 2025 · 10 comments

@Irvingwangjr

The registered hook doesn't trigger for torch versions below 2.5; it works with torch 2.5 and 2.6.
Could you kindly point out which PR changed the behavior of the registered hook? We want to use this with torch 2.4.

@felipemello1
Contributor

felipemello1 commented Mar 4, 2025

hey @Irvingwangjr, glad that you are interested. We only guarantee that it works with torch nightlies and the latest torch release. Specifically for activation offloading, use_streams only works for torch >= 2.5:

> [...] after torch-2.5.0. Default: True.

I am not sure which PR changed the behavior of the registered hook :/
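
For reference, a minimal sketch of gating the use_streams flag on the torch version (the import path and the version check are illustrative, not torchtune recipe code):

import torch
from torchtune.training import OffloadActivations  # import path assumed

# Requires a CUDA device; activations are offloaded to CPU during forward
# and brought back during backward.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()
x = torch.randn(8, 1024, device="cuda")

# The stream-based overlap path needs torch >= 2.5; fall back to the
# synchronous path on older versions.
with OffloadActivations(use_streams=torch.__version__ >= "2.5.0"):
    loss = model(x).sum()
loss.backward()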

@felipemello1
Contributor

If there are more questions, please feel free to reopen the issue!

@Irvingwangjr
Author

Irvingwangjr commented Mar 5, 2025

Many thanks for this! I also want to report a bug here.

Reproduction script:

import torch

from torchtune.training import OffloadActivations  # import path assumed


def test_offloading_works_with_cpu_tensors() -> None:

    class SomeFuncNeedsCpuTensors(torch.autograd.Function):
        @staticmethod
        def forward(ctx, cpu_tensor):
            # this op requires its saved tensor to stay on CPU
            assert cpu_tensor.device == torch.device("cpu")
            ctx.save_for_backward(cpu_tensor)
            return torch.rand_like(cpu_tensor)

        @staticmethod
        def backward(ctx, dy):
            saved = ctx.saved_tensors[0]
            # fails: OffloadActivations brought the saved tensor back onto GPU
            assert saved.device == torch.device("cpu")
            return torch.rand_like(saved)

    def fwd(c):
        a = SomeFuncNeedsCpuTensors.apply(c)
        return a.sum()

    tensor_c = torch.ones(10102, 1024, device="cpu", requires_grad=True)
    ctx = OffloadActivations(use_streams=False)
    with ctx:
        loss_c = fwd(tensor_c)
    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
    ctx.fwd_stash = {}
    loss_c.backward()

Some ops need certain tensors to stay on CPU. The current code ignores this and brings them back to GPU, as below:

  # Kick off the process to bring tensors back
  with torch.cuda.stream(self.s1):
      gpu_tensor = maybe_gpu_tensor.to("cuda", non_blocking=True)
      maybe_gpu_tensor = gpu_tensor

I think a better way is to store the device and bring the tensor back to that specific device in the unpack function:

def pack_hook(x):
    print("pack_hook", x)
    return (x.device, x.cpu())

def unpack_hook(packed):
    print("unpack_hook", packed)
    device, tensor = packed
    return tensor.to(device)
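
For illustration, here is how device-preserving hooks like these could be wired up with torch.autograd.graph.saved_tensors_hooks (just a minimal sketch, not the torchtune implementation):

import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_hook(x):
    # record the original device before moving the tensor to CPU
    return (x.device, x.cpu())

def unpack_hook(packed):
    # restore the tensor onto whatever device it originally lived on
    device, tensor = packed
    return tensor.to(device)

a = torch.randn(4, 4, requires_grad=True)  # CPU tensor in this toy example
with saved_tensors_hooks(pack_hook, unpack_hook):
    loss = (a * a).sum()
loss.backward()  # the saved tensor comes back on its original device (CPU here)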

@Irvingwangjr
Author

https://github.com/tgale96/grouped_gemm
This op is a real-world example: the grouped GEMM (gmm) op takes a parameter named batch_sizes, which needs to be a CPU tensor.

@janeyx99
Contributor

janeyx99 commented Mar 5, 2025

@Irvingwangjr Ah, good point. Would it then make more sense to only offload if the tensor is not on CPU?

@felipemello1
Contributor

felipemello1 commented Mar 5, 2025

thanks @Irvingwangjr, I asked @janeyx99 if she has availability to take a look.

The way we use it in torchtune is that we only enable it with activation checkpointing + we pass the input tensor to the transformer block. This tensor is always on GPU (unless you are training the model on CPU, which doesn't make much sense).

Can you explain your use case a bit more? I am not sure the bug you mentioned will ever occur in torchtune. Thank you!
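
For context, a minimal sketch of that pattern (a toy module standing in for a transformer block; illustrative only, not the actual torchtune recipe code, and the import path is assumed):

import torch
from torch.utils.checkpoint import checkpoint
from torchtune.training import OffloadActivations  # import path assumed

# toy stand-in for a transformer block; its input is already on GPU
block = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with OffloadActivations(use_streams=True):  # use_streams=True needs torch >= 2.5
    # activation checkpointing + offloading: the checkpoint saves (and offloads)
    # its inputs, while intermediates are recomputed during backward
    out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()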

@Irvingwangjr
Author

Irvingwangjr commented Mar 6, 2025

> thanks @Irvingwangjr, I asked @janeyx99 if she has availability to take a look.
>
> The way we use it in torchtune is that we only enable it with activation checkpointing + we pass the input tensor to the transformer block. This tensor is always on GPU (unless you are training the model on CPU, which doesn't make much sense).
>
> Can you explain your use case a bit more? I am not sure the bug you mentioned will ever occur in torchtune. Thank you!

Sure, I actually hit this problem when trying to integrate activation checkpointing and offloading with this op:
https://github.com/tgale96/grouped_gemm/blob/main/grouped_gemm/ops_test.py#L154

The parameter 'batch_sizes' is a CPU tensor indicating how the input is split into groups.

The current logic breaks this: it moves 'batch_sizes' to GPU, so the subsequent recomputation fails in the forward pass.

Since MoE and DeepSeek-V3 are becoming a trend, torchtune might also run into this problem (if it uses this op).
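
For illustration, roughly how the op is called (assuming the grouped_gemm package's ops.gmm interface; the shapes here are made up):

import torch
from grouped_gemm import ops  # https://github.com/tgale96/grouped_gemm

# a: tokens routed to 4 expert groups; b: one weight matrix per expert
a = torch.randn(12, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 64, 128, device="cuda", dtype=torch.bfloat16)
# batch_sizes must stay on CPU: it tells the kernel how `a` is split into groups
batch_sizes = torch.tensor([3, 3, 3, 3], dtype=torch.long, device="cpu")

out = ops.gmm(a, b, batch_sizes)  # breaks if batch_sizes was silently moved to GPU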

@Irvingwangjr
Author

Irvingwangjr commented Mar 6, 2025

> @Irvingwangjr Ah, good point. Would it then make more sense to only offload if the tensor is not on CPU?

Yeah, I actually patched the code like this:

        def pack_tensor(activation: torch.Tensor) -> int:
            ...
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()
            device = activation.device
            if num_bytes >= self.min_tensor_size_bytes and (
                not isinstance(activation, torch.nn.Parameter)
                and not isinstance(activation, torch.nn.parameter.Buffer)
            ) and device.type != "cpu":
                ...
                # offload ops
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is as is

            return (tensor_id, device)

As for the unpack function, I don't think it needs modification, since a CPU tensor will go down the 'not modified' branch. But I'm not sure what the difference is between calling tensor.to("cuda") and tensor.to("cuda:0"); if it makes no difference, I think we can ignore the device.

@janeyx99
Contributor

janeyx99 commented Mar 6, 2025

@Irvingwangjr If convenient, can you check whether this patch #2466 does the trick? I am specializing on CUDA here because our streaming logic only works with CUDA.

The difference between cuda and cuda:0 is that cuda:0 specifies a particular GPU, which can be limiting or unavailable depending on the user's environment, so it is better not to use "cuda:0" specifically.
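
A small illustration of the distinction (plain torch semantics, not specific to this patch):

import torch

# "cuda" carries no index: it resolves to the *current* CUDA device at use time,
# which can be changed with torch.cuda.set_device (e.g. per rank in distributed runs).
print(torch.device("cuda"))    # device(type='cuda')
print(torch.device("cuda:0"))  # device(type='cuda', index=0)

x = torch.ones(2, 2)
torch.cuda.set_device(0)   # pick the current device explicitly for this process
y = x.to("cuda")           # lands on the current device (cuda:0 here)
z = x.to("cuda:0")         # always cuda:0, regardless of the current device
print(y.device, z.device)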

@Irvingwangjr
Author

> @Irvingwangjr If convenient, can you check whether this patch #2466 does the trick? I am specializing on CUDA here because our streaming logic only works with CUDA.
>
> The difference between cuda and cuda:0 is that cuda:0 specifies a particular GPU, which can be limiting or unavailable depending on the user's environment, so it is better not to use "cuda:0" specifically.

Looks good to me!

But I still have a question about the device. Let's say a tensor is on device 'cuda:3'; we move it to CPU and then bring it back by calling tensor.to("cuda"). Which device will it end up on? Does it depend on some env variable like local_rank?

@felipemello1 added the enhancement and discussion labels on Mar 10, 2025