-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up for JOSS paper #19
Conversation
9d50830
to
2bd0b2b
Compare
2bd0b2b
to
3b414bb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of comments and a request to remove the reduce option.
jaxdecomp/_src/halo.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove the halo_reduce
option?
It is unclear what that option is supposed to be doing. Reading the docstring does not tell me what is supposed to happen, I don't think the users can use safely. And it also is implemented in jax, so they can do their own reduction operation if they want to after the halo exchange.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have actually removed it in another branch.
I implemented it in a much clever way in jaxPM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to go
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late review.
Very nice job on the documentation and typing.
IMO the code looks very clean and ready for the JOSS review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to get rid of the template, it can guide a future maintainer in the process.
@@ -1,19 +1,18 @@ | |||
# Change log | |||
|
|||
|
|||
<!-- Template for documenting changes | |||
## jaxdecomp 0.0.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we really releasing a working code with several new additions as v0.0.1
?
I would personally prefer something like 0.1.0
and if we are going for https://semver.org/ they would even prefer changing the major version since some of the changes are not backward compatible (to my knowledge).
I know this is a big leap so I would go for 0.1 if you guys agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have enough open source experience to know the implication of a minor vs patch version
But I feel more like this was a minor version rather than a patch
(Then again the major version will always be 0 ..)
* Added custom partitioning for slice_pad and slice_unpad | ||
* Add example for multi-host FFTs in `examples/jaxdecomp_lpt.py` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Careful with consistency Added or Add but not both.
set(CUDECOMP_CUDA_CC_LIST "70;80" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") | ||
|
||
# 70: Volta, 80: Ampere, 89: RTX 4060 | ||
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we already prepare for H100 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right.
I will add 90
return (double_precision == other.double_precision && halo_extents[0] == other.halo_extents[0] && | ||
halo_extents[1] == other.halo_extents[1] && halo_extents[2] == other.halo_extents[2] && | ||
halo_periods[0] == other.halo_periods[0] && halo_periods[1] == other.halo_periods[1] && | ||
halo_periods[2] == other.halo_periods[2] && axis == other.axis && | ||
config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] && | ||
config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] && | ||
config.pdims[1] == other.config.pdims[1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you put one test per line for readability ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The formatter does the magic here
Its google I think .. I will look into it
def abstract(x: Array, fft_type: xla_client.FftType, pdims: Tuple[int, int], | ||
global_shape: Tuple[int, int, | ||
int], adjoint: bool) -> ShapedArray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Weird formatting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yapf..
I use pre-commit to format
match fft_type: | ||
case xla_client.FftType.FFT: | ||
# FFT is X to Y to Z so Z-Pencil is returned | ||
# Except if we are doing a YZ slab in which case we return a Y-Pencil | ||
transpose_shape = (1, 2, 0) | ||
transposed_pdims = pdims | ||
case xla_client.FftType.IFFT: | ||
# IFFT is Z to X to Y so X-Pencil is returned | ||
# In YZ slab case we only need one transposition back to get the X-Pencil | ||
transpose_shape = (2, 0, 1) | ||
transposed_pdims = pdims | ||
case _: | ||
raise TypeError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure to add Python ≥ 3.10 in the requirements to be allowed to use such pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point
Major clean-up for JOSS paper