Address issue 21 and add distributed_tensor module #22
Conversation
Removed duplicate description of distributed scalers and corrected DataFrame creation and reading methods.
This module mirrors the code structure and methodology of distributed.py, but focuses specifically on implementing distributed tensor-scaling classes for PyTorch. Obsolete attributes (e.g., self.is_array) and unused methods (such as extract_array, get_column_order, and package_transformed_x) from distributed.py have been removed. The extract_x_columns method has also been simplified. For the fit method, input tensors are expected to be free of NaN values—a reasonable requirement since training datasets should not contain NaNs. The module requires PyTorch 2.8.0, which is enforced via an assertion at initialization.
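As a rough illustration of the approach described above, here is a minimal sketch of a distributed standard scaler for tensors. It is not the PR's actual implementation: the attribute names (`n_`, `mean_`, `var_`) and the pooled-statistics combination are assumptions for illustration; only the class name `DStandardScalerTensor` comes from the PR.

```python
import torch


class DStandardScalerTensor:
    """Illustrative sketch of a distributed standard scaler for torch tensors."""

    def __init__(self):
        # The PR enforces a PyTorch version at initialization; the exact check may differ.
        major, minor = (int(v) for v in torch.__version__.split("+")[0].split(".")[:2])
        assert (major, minor) >= (2, 0), "requires PyTorch >= 2.0.0"
        self.n_ = 0
        self.mean_ = None
        self.var_ = None

    def fit(self, x: torch.Tensor):
        # Input tensors are expected to be free of NaNs (per the PR description).
        assert not torch.isnan(x).any(), "fit() expects NaN-free input"
        self.n_ = x.shape[0]
        self.mean_ = x.mean(dim=0)
        # correction=0 gives the population variance; "correction" replaces the
        # old "unbiased" argument as of PyTorch 2.0.
        self.var_ = x.var(dim=0, correction=0)
        return self

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean_) / torch.sqrt(self.var_)

    def __add__(self, other):
        # Pool statistics so scalers fit on disjoint shards can be combined;
        # the actual combination logic in distributed.py may differ.
        n = self.n_ + other.n_
        mean = (self.n_ * self.mean_ + other.n_ * other.mean_) / n
        var = (self.n_ * (self.var_ + self.mean_ ** 2)
               + other.n_ * (other.var_ + other.mean_ ** 2)) / n - mean ** 2
        combined = DStandardScalerTensor()
        combined.n_, combined.mean_, combined.var_ = n, mean, var
        return combined
```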
save_scaler is commented out for now, as the custom serialization for tensors still needs to be built.
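For reference, one possible shape the custom serialization could take, assuming the scaler stores plain tensor attributes as in the sketch above; `save_scaler`/`load_scaler` here are hypothetical stand-ins, not the module's final API.

```python
import json

import torch


def save_scaler(scaler, path):
    # Hypothetical serialization: convert tensor attributes to plain lists so the
    # state is JSON-serializable; the PR leaves the real implementation for later.
    state = {
        "n_": scaler.n_,
        "mean_": scaler.mean_.tolist(),
        "var_": scaler.var_.tolist(),
    }
    with open(path, "w") as f:
        json.dump(state, f)


def load_scaler(cls, path):
    # Rebuild tensors from the stored lists and restore them onto a new scaler.
    with open(path) as f:
        state = json.load(f)
    scaler = cls()
    scaler.n_ = state["n_"]
    scaler.mean_ = torch.tensor(state["mean_"])
    scaler.var_ = torch.tensor(state["var_"])
    return scaler
```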
Moving the tests for the distributed_tensor module to a separate script.
Add unit tests for DStandardScalerTensor and DMinMaxScalerTensor, following the example in distributed_test.py for DStandardScaler and DMinMaxScaler.
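A test along the lines of distributed_test.py might check that scalers fit on disjoint shards combine to match a scaler fit on the full data. The snippet below assumes the illustrative scaler sketch above (`mean_`, `var_`, and the `+` combination are assumed names, not necessarily the PR's API).

```python
import torch

# Hypothetical import path; the PR places these tests in a separate script.
# from bridgescaler.distributed_tensor import DStandardScalerTensor


def test_dstandard_scaler_tensor_combination():
    torch.manual_seed(0)
    x = torch.randn(200, 4) * 5.0 + 3.0
    # Fit one scaler on the full data and two on disjoint halves, then combine.
    full = DStandardScalerTensor().fit(x)
    combined = DStandardScalerTensor().fit(x[:100]) + DStandardScalerTensor().fit(x[100:])
    # Combined statistics and transforms should match the full fit to float precision.
    assert torch.allclose(full.mean_, combined.mean_, atol=1e-6)
    assert torch.allclose(full.var_, combined.var_, atol=1e-5)
    assert torch.allclose(full.transform(x), combined.transform(x), atol=1e-5)
```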
Besides addressing issue #21, I have also added the distributed_tensor module. DStandardScalerTensor and DMinMaxScalerTensor are tested with the example in the docs and produce identical results (see screenshots below). Happy Thanksgiving! :)
djgagne
left a comment
I have some small requested changes to help the CI tests pass and to support more than a single base version of PyTorch.
The restriction to PyTorch 2.8.0 applied only to an early iteration of the code and is no longer relevant. According to the documentation, the "unbiased" argument of torch.var was renamed to "correction" starting with PyTorch 2.0, so impose a minimum version requirement of 2.0.0 instead. I tested the module with the latest release, 2.9.1, and other versions >= 2.0.0 also worked fine.
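To make this concrete, the snippet below contrasts the pre-2.0 boolean flag with the "correction" argument and shows one possible minimum-version guard; the exact check used in the module may differ.

```python
import torch

x = torch.randn(8, 3)

# Pre-2.0 style (boolean flag): unbiased=False -> population variance.
# v_old = torch.var(x, dim=0, unbiased=False)

# PyTorch >= 2.0 style: "correction" replaces "unbiased"; correction=0 is
# equivalent to unbiased=False, correction=1 to unbiased=True (the default).
v_new = torch.var(x, dim=0, correction=0)

# A simple minimum-version guard, as suggested in the review (illustrative only).
assert tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (2, 0), \
    "distributed_tensor requires PyTorch >= 2.0.0"
```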
Comments addressed, and CI tests passed.
djgagne
left a comment
LGTM


see #21