Delay transform refactor and performance improvements#321
Conversation
88099f5 to
edf6ebd
Compare
d9f36d3 to
24dded8
Compare
7bc3fd2 to
2ff0c2d
Compare
24dded8 to
89e3ba8
Compare
f3e2e31 to
6bab35f
Compare
ketiltrout
left a comment
There was a problem hiding this comment.
For the most part. I think this is a good refactoring.
I do have some issues with the way you've implemented DelayTransformContainerMixin as I try to describe below.
Pretty much all of the delay transform and power spectrum classes have had their names changed, and I have not added stubs for backwards compatibility with the intention of forcing people to change. I think the new names should be clearer.
I'll leave commenting further on this for the people actually making use of draco, but I think there is a time and place for breaking backwards compatibility like this. The counter-argument, though is: you don't want to encourage people not upgrading to save themselves the work of having to refactor their code.
|
I've made a few changes which hopefully address all comments:
|
c4ccfa3 to
1a18dd1
Compare
ssiegelx
left a comment
There was a problem hiding this comment.
I think this cleans up delay.py nicely and makes it quite a bit easier to follow.
I’m not able to track all of the changes, so it would be good to test and confirm that we get the same results as before.
I'm also unsure about the dependency on JAX. What kind of speed-up did you see in the Gibbs sampler on Cedar? (I think you mentioned a 1.5x improvement on your laptop.) Maybe @Arnab-half-blood-prince can weigh in on whether this is a bottleneck in any of his analyses. I believe we’re using the NRML in the daily pipeline?
One other thing I noticed: as currently written, the Gibbs sampler doesn’t seem to be using its rng argument at all, instead generating random seeds using JAX within delay_power_spectrum_gibbs. This means we lose the ability to track the exact random seed used to generate the spectrum and reproduce the results. In fact, the saved seed in the file would be incorrect. Would there be a way to create a JAX subclass of RandomTask to handle this?
|
Good point about the random seed. I think I could come up with a way to work around this, but it would also be easy to revert and remove jax, if that performance improvement isn't significant enough to warrant the extra dependency and complexity |
|
If we don't want to revert it completely, how much work would it be to split the JAX stuff out into its own PR? |
|
@ssiegelx The Gibbs sampler is used only once for the data, and then we use that Power spectrum, fed into weiner filter and then generate delay spectrum for data, 21cm and noise. So, the estimation of PS is only once. All the delay spectrum generation, with Wiener filter is super fast, so not a bottleneck in any analysis. But, I think, we should check this with real data and see is there any difference between earlier version with this new one. |
I made it its own commit, so it should be easy. I think I'll do that |
|
I've moved the |
|
Now that the |
ketiltrout
left a comment
There was a problem hiding this comment.
This does seem to improve things.
I also appreciate the number of comments you're adding to the code.
This refactor does not change the base delay transform, but it aims to separate code which deals with input and output container logic from code which deals with different estimators. It also tries to better separate direct delay transforms, which can also return a power spectrum, from more complicated algorithms used to directly compute a power spectrum.
- Switch to a matern 1.5 prior, which shows faster convergance and better behaviour around large deltas. - Scale the initial fourier transform guess to avoid local minima near initial guess. - Enforce 64-bit precision in the nrml estimator for numerical stability. - Move gaussian process prior kernels to separate module - Remove 'SmoothnessRegulariser' as it is implemented in gaussian prior
This PR refactors certain parts of the delay transform/power spectrum code, and provides performance improvements to the NRML, Gibbs, and Wiener filter estimators. It's based against #318, so that needs to be reviewed/merged first.
Changes
NRML Estimator
This PR adds two functional changes to the NRML estimator:
LinAlgErrorandValueErrorfailure mode during the minimisation.There's also a bit of a refactor in
delayopt.pyto reduce repeated/unnecessary code.Refactor
The main goal of this refactor is to separate the code which handles input/output containers from the code which computes a delay transform or power spectrum, and to split the Gibbs and NRML estimators into their own classes. I also remove the Stokes I estimator in favour of using the
transform.StokesItask added in #318 and feeding that into a general power spectrum estimator.The git diff for this refactor is a bit rough to look at because of how things have been moved around, so I suggest looking at the code directly. There are essentially no changes to the function of the individual methods.
Performance
The most drastic improvement is to the NRML estimator, with roughly a 3x speed improvement in most cases. There are still cases where convergence is slow, which seems to happen when the noise weights are below a certain level and the estimator starts trying to remove all high-delay power. This is something that needs further analysis, and isn't necessarily an issue with the estimator itself.
The Wiener filter is already fast, so this speedup doesn't matter too much. My initial goal was to see if I could speed it up enough to support a variable frequency mask in time - this is possible, but running on my macbook air takes ~1.5 minutes per baseline, as opposed to <1 second. I haven't implemented that here, but it's straightforward to change.
The Gibbs sampler is mostly just improved through the use of theThis is moved to #336jaxlibrary. I'm generally able to get around a 1.5x speedup on my laptop, although I expect this to potentially be better on Cedar when more cores are available. However, this does add an extra dependency, and the NRML estimator performs well enough that we maybe don't have to bother with minor speed improvements to the Gibbs sampler (I see around 4.5 seconds vs 20 seconds for NRML vs Gibbs). I'm open to opinions about reverting this.Deprecations
Pretty much all of the delay transform and power spectrum classes have had their names changed, and I have not added stubs for backwards compatibility with the intention of forcing people to change. I think the new names should be clearer. I can revert this is people would like.