
Address pytorch versioning issues. #820

Open · wants to merge 3 commits into base: main
Conversation

coreyjadams (Collaborator)

Many new features of physicsnemo's distributed utilities, targeting domain parallelism, require pytorch's DTensor package, which was introduced in pytorch 2.6.0. But we don't want to limit physicsnemo usage unnecessarily.

This commit introduces version-checking utilities, which are then applied to ShardTensor. If torch is below 2.6.0, the distributed utilities will not import ShardTensor but will otherwise still work. If a user attempts to import ShardTensor directly, bypassing the __init__.py file, the version-checking utilities will raise an exception.
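For concreteness, here is a minimal sketch of that guard pattern, assuming hypothetical module paths and a PEP 440 comparison via packaging.version; the actual physicsnemo layout and helper names may differ:

```python
# physicsnemo/distributed/__init__.py (hypothetical path)
import torch
from packaging.version import Version

# .release compares only the (major, minor, patch) tuple, so 2.6.0
# pre-releases such as 2.6.0a0 also satisfy the check.
if Version(torch.__version__).release >= (2, 6, 0):
    from .shard_tensor import ShardTensor  # hypothetical module name

# physicsnemo/distributed/shard_tensor.py (hypothetical path)
# A direct import that bypasses __init__.py hits this guard and fails loudly:
if Version(torch.__version__).release < (2, 6, 0):
    raise ImportError(
        f"ShardTensor requires torch >= 2.6.0; found {torch.__version__}."
    )
```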

Tests on ShardTensor are likewise skipped if torch >= 2.6.0 is not installed.
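A hedged sketch of the corresponding test guard (the test name and module-level mark are placeholders; the real suite apparently avoids even collecting these files):

```python
import pytest
import torch
from packaging.version import Version

# Module-level mark: every test in this file is skipped on older torch,
# so ShardTensor is never imported there.
pytestmark = pytest.mark.skipif(
    Version(torch.__version__).release < (2, 6, 0),
    reason="ShardTensor requires torch >= 2.6.0",
)

def test_shard_tensor_import():
    # Safe here: this body only executes when the skipif condition is False.
    from physicsnemo.distributed import ShardTensor
    assert ShardTensor is not None
```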

Finally, an additional test file is included to validate the version-checking tools.

PhysicsNeMo Pull Request

Closes #815

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

coreyjadams (Collaborator, Author)

@NickGeneva can you check whether this solves your earth_2 issues? One item that may be outstanding: DeviceMesh, which is now used in DistributedManager, was introduced in pytorch 2.2.0. I suspect that is roughly OK, and if not we could bump the minimum pytorch version of physicsnemo to 2.2 (not all the way to 2.6, as ShardTensor needs).

My local testing saw nearly all tests passing, with a crash in the one test where torch.distributed is initialized twice. I believe that's a pytorch bug, but I want to see what the CI does with it.

coreyjadams (Collaborator, Author)

/blossom-ci

- change the ShardTensor minimum version to 2.5.9 to accommodate alpha releases of 2.6.0a
- set the minimum pytorch version for DeviceMesh to 2.4.0
- introduce a function decorator that raises an exception when unavailable functions are called
- add a little more protection in the tests to differentiate …
coreyjadams (Collaborator, Author)

I've updated to include multiple levels of checking:

  • DeviceMesh requires torch >= 2.4.0
  • ShardTensor requires torch > 2.5.9. Note that torch 2.5.9 does not exist; the bound is chosen so that alpha builds of 2.6.0 (2.6.0a...), which do work, pass the check (see the sketch below).
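
The pre-release ordering behind that trick can be verified directly, assuming the comparison uses packaging.version or anything else with PEP 440 semantics:

```python
from packaging.version import Version

# PEP 440 ordering: a pre-release sorts below its final release...
assert Version("2.6.0a0") < Version("2.6.0")
# ...but above any lower version, including the fictitious 2.5.9.
assert Version("2.6.0a0") > Version("2.5.9")
# Hence "torch > 2.5.9" admits 2.6.0 alphas that "torch >= 2.6.0" would reject.
```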

The DistributedManager API is unchanged, but several functions are now wrapped in a modulus.utils.version_check.require_version decorator that requires torch >= 2.4.0. If the user calls these functions on an older torch, an ImportError is raised. Importing and using DistributedManager with functionality that was available before the 25.03 release should not be affected.
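
A minimal sketch of such a decorator; the names, signature, and error message here are illustrative assumptions, not the actual modulus/physicsnemo implementation:

```python
from functools import wraps

import torch
from packaging.version import Version

def require_version(min_torch: str):
    """Raise ImportError when the wrapped function runs under a torch
    older than ``min_torch``. Illustrative sketch only."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if Version(torch.__version__) < Version(min_torch):
                raise ImportError(
                    f"{fn.__qualname__} requires torch >= {min_torch}; "
                    f"found {torch.__version__}."
                )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

# Hypothetical usage on a DeviceMesh-dependent helper:
@require_version("2.4.0")
def create_device_mesh(mesh_shape):
    # Body elided; would build a torch.distributed.device_mesh.DeviceMesh.
    ...
```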

Testing on ORD, I have the following results for these containers from ngc:

  • Pytorch-24.01, torch version 2.2.0
    • Single-GPU tests pass, except test_mesh_datapipe, which needs X11.
    • Multi-GPU tests run:
      • all tests pass except the mesh_datapipe tests (same reason as above)
      • shard tensor tests are skipped (not even collected) (test_shard_tensor...)
      • device_mesh tests are skipped (not even collected) (test_mesh.py)
I'll let the CI test 2.6.0a.

coreyjadams (Collaborator, Author)

/blossom-ci

Successfully merging this pull request may close: 🐛[BUG]: Older torch versions do not support the latest distributed tools