[Question]: activation offload won't work for torch version < 2.5 #2456
Comments
hey @Irvingwangjr , glad that you are interested. We only guarantee that it works with torch nightlies and the latest torch release. Specifically for activation offloading, use_streams only works for torch >= 2.5 (see the version-gate sketch after this comment).
I am not sure which PR changed the behavior of the registered hook :/
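For anyone hitting this on an older torch, a minimal sketch of gating the flag on the torch version. It assumes the offloading context manager is exposed as torchtune.training.OffloadActivations with the use_streams argument mentioned above; treat the import path as an assumption rather than a confirmed API.

```python
# Minimal sketch (not from this thread): only enable use_streams on torch >= 2.5,
# since the stream-based offloading path relies on newer hook behavior.
# Assumption: the context manager is torchtune.training.OffloadActivations.
import torch
from torchtune.training import OffloadActivations

# torch.__version__ is a TorchVersion, so comparison against a tuple works
use_streams = torch.__version__ >= (2, 5)

with OffloadActivations(use_streams=use_streams):
    ...  # run the forward pass whose activations should be offloaded to CPU
```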
If there are more questions, please feel free to reopen the issue!
Many thanks for this! I also want to report a bug here, reproduction script:
For some ops, certain tensors need to be on CPU:
I think a better way is to store the device and bring the tensors back to that specific device in the unpack function.
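A minimal sketch of that idea, using torch.autograd.graph.saved_tensors_hooks; the names are illustrative, not torchtune's actual offloading code.

```python
# Illustrative sketch of the suggestion above: remember each saved tensor's
# original device at pack time and restore it in the unpack hook.
import torch

def pack(tensor):
    original_device = tensor.device           # could be cpu, cuda:0, cuda:3, ...
    return original_device, tensor.to("cpu")  # offload to host memory

def unpack(packed):
    original_device, cpu_tensor = packed
    return cpu_tensor.to(original_device)     # bring it back to where it came from

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    # forward pass of the offloaded region goes here
    ...
```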
https://github.com/tgale96/grouped_gemm
@Irvingwangjr Ah, good point. Would it then make more sense to only offload if the tensor is not on CPU?
Thanks @Irvingwangjr , I asked @janeyx99 if she has availability to take a look. The way we use it in torchtune is that we only enable it with activation checkpointing, and we pass the input tensor to the transformer block. This tensor is always on GPU (unless you are training the model on CPU, which doesn't make much sense). Can you explain your use case a bit more? Not sure the bug you mentioned will ever occur in torchtune. Thank you!
Sure, I actually hit this problem when trying to integrate activation checkpointing and offloading with this op: the parameter 'batch_sizes' is a CPU tensor indicating how the input is split into groups. The current logic breaks this because it moves 'batch_sizes' to GPU, so the subsequent recomputation fails in the forward pass. Since MoE and DeepSeek-V3 are becoming a trend, torchtune might also hit this problem (if it uses this op).
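For concreteness, a minimal sketch of the failure mode. It assumes grouped_gemm's gg.ops.gmm(a, b, batch_sizes) interface with batch_sizes required on CPU, as described above; the shapes are made up for illustration.

```python
# Sketch only: batch_sizes must stay on CPU (per the report above), so offload
# hooks that unconditionally restore saved tensors to "cuda" break the
# recomputation done by activation checkpointing.
import torch
import grouped_gemm as gg

num_experts, hidden, out_dim = 2, 16, 32
a = torch.randn(8, hidden, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_experts, hidden, out_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
batch_sizes = torch.tensor([3, 5], dtype=torch.long)  # CPU tensor: rows of `a` per expert

out = gg.ops.gmm(a, b, batch_sizes)  # fine: batch_sizes is on CPU
# If the unpack hook later returns batch_sizes on "cuda", the recomputed
# forward fails because the op expects this argument on CPU.
```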
Yeah, I actually patched the code like this:
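Roughly along these lines; this is a self-contained sketch of the described patch rather than the exact torchtune code (the real unpack already has a 'not modified' branch, the tag below just stands in for that).

```python
# Sketch of the patch idea: skip offloading for tensors that are not on CUDA
# (e.g. the CPU batch_sizes), and tag each saved tensor so unpack knows
# whether it was modified.
import torch

def pack(tensor):
    if tensor.device.type != "cuda":
        return ("unmodified", tensor)                      # leave CPU tensors alone
    return ("offloaded", tensor.device, tensor.to("cpu"))  # offload CUDA activations

def unpack(packed):
    if packed[0] == "unmodified":
        return packed[1]                                   # hand it back untouched
    _, original_device, cpu_tensor = packed
    return cpu_tensor.to(original_device)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    ...
```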
And for the unpack function, I think it doesn't need modification since CPU tensors will go to the 'not modified' branch. But I'm not sure what the difference is between calling tensor.to("cuda") and tensor.to("cuda:0"); if there is no difference, I think we can ignore the device.
@Irvingwangjr if convenient, can you check that this patch #2466 does the trick? I am specializing on CUDA here because our streaming logic only works on CUDA. The difference between cuda and cuda:0 is that cuda:0 specifies a particular GPU, which can be limiting or unavailable depending on the user's environment, so it is better not to use "cuda:0" specifically.
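To illustrate the point (sketch assuming a machine with at least two GPUs): "cuda" resolves to the current CUDA device, while "cuda:0" always pins GPU 0.

```python
# "cuda" resolves to the current CUDA device (typically set per rank in
# distributed runs via torch.cuda.set_device(local_rank)); "cuda:0" is always GPU 0.
import torch

torch.cuda.set_device(1)             # pretend this rank's local device is GPU 1
x = torch.randn(4, device="cuda:1")

y = x.cpu().to("cuda")               # comes back on the current device
print(y.device)                      # cuda:1

z = x.cpu().to("cuda:0")             # explicitly pinned to GPU 0
print(z.device)                      # cuda:0
```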
Looks good to me! But I still have a question about the device. Let's say a tensor is on device 'cuda:3'; we move it to CPU and bring it back by calling tensor.to("cuda"): which device will it land on? Does it depend on some environment variable like local_rank?
The registered hook doesn't trigger for torch versions under 2.5; it works for torch 2.5 and 2.6.
Could you kindly point out which PR changed the behavior of the registered hook? We want to use this with torch 2.4.