
Delay transform refactor and performance improvements#321

Merged
ljgray merged 5 commits into master from ljg/delay_perf
Mar 25, 2025

Conversation

@ljgray
Contributor

@ljgray ljgray commented Feb 6, 2025

This PR refactors parts of the delay transform/power spectrum code and provides performance improvements to the NRML, Gibbs, and Wiener filter estimators. It's based on #318, so that needs to be reviewed and merged first.

Changes

NRML Estimator

This PR adds two functional changes to the NRML estimator:

  1. Switches to a Matérn prior rather than a Gaussian one. The Matérn kernel converges faster and reduces the excessive smoothing introduced by the Gaussian prior. It also ends up being somewhat less sensitive to specific hyperparameters, and is better able to model large (multiple orders of magnitude) jumps in power.
  2. Imposes numerical bounds on the minimization parameters to eliminate overflows in the exponential. This isn't meant to be a true bounded minimization, just a way to enforce numerical stability. In my testing this eliminates the LinAlgError and ValueError failure modes during the minimization.
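The two changes above can be sketched roughly as follows. This is a minimal illustration with hypothetical names, not the draco implementation; the Matérn 3/2 form corresponds to the "matern 1.5" kernel mentioned in the commit messages.

```python
import numpy as np

def matern32_kernel(x, lengthscale=1.0, variance=1.0):
    # Matern nu=3/2 covariance: its samples are only once-differentiable, so
    # the prior can track sharp, multi-decade jumps in power that the smoother
    # Gaussian (RBF) prior below tends to wash out.
    r = np.abs(x[:, None] - x[None, :]) / lengthscale
    return variance * (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)

def gaussian_kernel(x, lengthscale=1.0, variance=1.0):
    # The previous Gaussian (squared-exponential) prior, for comparison.
    r2 = ((x[:, None] - x[None, :]) / lengthscale) ** 2
    return variance * np.exp(-0.5 * r2)

def safe_exp(log_power, bound=50.0):
    # Numerical-stability guard (change 2): clip the minimization parameters
    # before exponentiating so exp() cannot overflow. Not a true bounded
    # minimization, just a stability clamp.
    return np.exp(np.clip(log_power, -bound, bound))
```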

There's also a bit of a refactor in delayopt.py to reduce repeated/unnecessary code.

Refactor

The main goal of this refactor is to separate the code which handles input/output containers from the code which computes a delay transform or power spectrum, and to split the Gibbs and NRML estimators into their own classes. I also remove the Stokes I estimator in favour of using the transform.StokesI task added in #318 and feeding that into a general power spectrum estimator.

The git diff for this refactor is a bit rough to look at because of how things have been moved around, so I suggest looking at the code directly. There are essentially no changes to the function of the individual methods.

Performance

The most drastic improvement is to the NRML estimator, with roughly a 3x speed improvement in most cases. There are still cases where convergence is slow, which seems to happen when the noise weights are below a certain level and the estimator starts trying to remove all high-delay power. This is something that needs further analysis, and isn't necessarily an issue with the estimator itself.

The Wiener filter is already fast, so the speedup there matters less. My initial goal was to see whether I could make it fast enough to support a frequency mask that varies in time. This is possible, but running on my MacBook Air it takes ~1.5 minutes per baseline, as opposed to <1 second with a fixed mask. I haven't implemented it here, but it's a straightforward change.
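For context, a per-baseline Wiener filter of the assumed form (a hedged sketch, not the draco implementation) looks like this; with a time-variable frequency mask the bracketed covariance changes every time sample, so the solve can no longer be factored once and reused, which is why the variable-mask version is so much slower.

```python
import numpy as np

def wiener_delay_estimate(data, F, S, N):
    # Wiener estimate s_hat = S F^H (F S F^H + N)^{-1} d, where F maps
    # delay -> frequency, S is the diagonal signal prior (the delay power
    # spectrum) and N is the diagonal per-frequency noise variance.
    # A frequency mask is equivalent to dropping rows of F, so a mask that
    # varies in time forces a fresh solve for every time sample.
    C = (F * S) @ F.conj().T + np.diag(N)        # F S F^H + N
    return (S[:, None] * F.conj().T) @ np.linalg.solve(C, data)
```

In the noiseless, unitary-transform limit this reduces to the exact inverse transform, which is a convenient sanity check.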

The Gibbs sampler is mostly improved through the use of the jax library. I generally get around a 1.5x speedup on my laptop, although I expect this could be better on Cedar, where more cores are available. However, this adds an extra dependency, and the NRML estimator performs well enough that we may not need minor speed improvements to the Gibbs sampler (I see around 4.5 seconds for NRML vs 20 seconds for Gibbs). I'm open to opinions about reverting this. Update: this has been moved to #336.

Deprecations

Pretty much all of the delay transform and power spectrum classes have had their names changed, and I have not added stubs for backwards compatibility, with the intention of forcing people to change. I think the new names should be clearer. I can revert this if people would like.

@ljgray ljgray force-pushed the ljg/delay_perf branch 6 times, most recently from d9f36d3 to 24dded8, on February 13, 2025 at 01:43
@ljgray ljgray marked this pull request as ready for review March 4, 2025 20:32
@ljgray ljgray force-pushed the ljg/dpss-pspec branch 4 times, most recently from f3e2e31 to 6bab35f, on March 10, 2025 at 22:28
Base automatically changed from ljg/dpss-pspec to master March 10, 2025 23:36
Member

@ketiltrout ketiltrout left a comment


For the most part, I think this is a good refactoring.

I do have some issues with the way you've implemented DelayTransformContainerMixin as I try to describe below.

Pretty much all of the delay transform and power spectrum classes have had their names changed, and I have not added stubs for backwards compatibility with the intention of forcing people to change. I think the new names should be clearer.

I'll leave further comment on this to the people actually making use of draco, but I think there is a time and place for breaking backwards compatibility like this. The counter-argument, though, is that you don't want to encourage people to avoid upgrading just to save themselves the work of refactoring their code.

@ljgray
Contributor Author

ljgray commented Mar 18, 2025

I've made a few changes which hopefully address all comments:

  • Renamed the DelayTransform classes to DelaySpectrum to be consistent with naming
  • Added a DelaySpectrumBase, which inherits from the DelaySpectrumContainer... class, to be consistent with DelayPowerSpectrumBase
  • Removed the powerspectrum option entirely from the DelaySpectrum classes, and instead added a simple task to produce a power spectrum from a delay spectrum. The only downside is that it requires computing the entire delay spectrum first, rather than doing it per baseline
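The core of that follow-on task is small enough to sketch (hypothetical names, not the actual draco task): the delay power spectrum is just the squared magnitude of the complex delay spectrum, averaged over any sample axis.

```python
import numpy as np

def delay_power_spectrum_from_spectrum(delay_spectrum, sample_axis=0):
    # Squared magnitude of the complex delay spectrum, averaged over the
    # sample/realisation axis. The trade-off noted above: this needs the
    # full delay spectrum in memory first, rather than working per baseline.
    return np.mean(np.abs(delay_spectrum) ** 2, axis=sample_axis)
```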

@ljgray ljgray requested a review from ketiltrout March 18, 2025 01:05
@ljgray ljgray force-pushed the ljg/delay_perf branch 5 times, most recently from c4ccfa3 to 1a18dd1, on March 20, 2025 at 17:56
Contributor

@ssiegelx ssiegelx left a comment


I think this cleans up delay.py nicely and makes it quite a bit easier to follow.

I’m not able to track all of the changes, so it would be good to test and confirm that we get the same results as before.

I'm also unsure about the dependency on JAX. What kind of speed-up did you see in the Gibbs sampler on Cedar? (I think you mentioned a 1.5x improvement on your laptop.) Maybe @Arnab-half-blood-prince can weigh in on whether this is a bottleneck in any of his analyses. I believe we’re using the NRML in the daily pipeline?

One other thing I noticed: as currently written, the Gibbs sampler doesn’t seem to be using its rng argument at all, instead generating random seeds using JAX within delay_power_spectrum_gibbs. This means we lose the ability to track the exact random seed used to generate the spectrum and reproduce the results. In fact, the saved seed in the file would be incorrect. Would there be a way to create a JAX subclass of RandomTask to handle this?
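One way the suggestion above could work (a hedged sketch with hypothetical names, not draco's actual RandomTask API): keep the single saved integer seed as the source of truth and derive any JAX key from it deterministically, so the seed written to the output file really does reproduce the spectrum.

```python
import numpy as np

class JaxSeedTrackingSketch:
    # Hypothetical task mixin: the saved `seed` attribute remains the
    # single source of truth for reproducibility.
    def __init__(self, seed):
        self.seed = seed                         # this is what gets saved
        self._rng = np.random.default_rng(seed)

    def next_jax_seed(self):
        # Deterministic 32-bit draw from the task rng. Feeding this integer
        # to jax.random.PRNGKey(...) ties the JAX random stream back to the
        # saved task seed, instead of generating seeds inside JAX.
        return int(self._rng.integers(0, 2**31 - 1))
```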

@ljgray
Contributor Author

ljgray commented Mar 21, 2025

Good point about the random seed. I think I could come up with a way to work around this, but it would also be easy to revert and remove jax if the performance improvement isn't significant enough to warrant the extra dependency and complexity.

@ketiltrout
Member

If we don't want to revert it completely, how much work would it be to split the JAX stuff out into its own PR?

@Arnab-half-blood-prince
Contributor

@ssiegelx The Gibbs sampler is used only once for the data; that power spectrum is then fed into the Wiener filter to generate delay spectra for the data, 21cm signal, and noise. So the power spectrum estimation happens only once. All the delay spectrum generation with the Wiener filter is very fast, so it's not a bottleneck in any analysis. But I think we should check this with real data and see whether there is any difference between the earlier version and this new one.

@ljgray
Contributor Author

ljgray commented Mar 21, 2025

If we don't want to revert it completely, how much work would it be to split the JAX stuff out into its own PR?

I made it its own commit, so it should be easy. I think I'll do that.

@ljgray
Contributor Author

ljgray commented Mar 23, 2025

I've moved the jax changes into #336, which will be a WIP for now.

@ljgray
Contributor Author

ljgray commented Mar 24, 2025

Now that the jax changes have been moved and we've checked to confirm that results are unchanged, I think this should be ok to merge. @ssiegelx @Arnab-half-blood-prince Do you have any issues with this being merged now?

Member

@ketiltrout ketiltrout left a comment


This does seem to improve things.

I also appreciate the number of comments you're adding to the code.

ljgray added 5 commits March 25, 2025 11:24
This refactor does not change the base delay transform, but it
aims to separate code which deals with input and output container
logic from code which deals with different estimators. It also
tries to better separate direct delay transforms, which can also
return a power spectrum, from more complicated algorithms used to
directly compute a power spectrum.
- Switch to a Matérn 1.5 prior, which shows faster convergence and
better behaviour around large deltas.
- Scale the initial Fourier transform guess to avoid local minima
near the initial guess.
- Enforce 64-bit precision in the NRML estimator for numerical stability.
- Move Gaussian process prior kernels to a separate module
- Remove 'SmoothnessRegulariser' as it is implemented in the Gaussian prior
@ljgray ljgray merged commit 43e5f04 into master Mar 25, 2025
4 checks passed
@ljgray ljgray deleted the ljg/delay_perf branch March 25, 2025 18:35