Incorporating AePPL into GaussianRandomWalk #5814
Codecov Report
```diff
@@            Coverage Diff             @@
##             main    #5814      +/-   ##
==========================================
- Coverage   89.26%   83.01%    -6.25%
==========================================
  Files          72       73        +1
  Lines       12890    13261      +371
==========================================
- Hits        11506    11009      -497
- Misses       1384     2252      +868
```
Nice! However, this isn't the same thing; it's a different parameterization of the same thing, like centered vs non-centered. So I don't think we should remove the previous one: people might be using it, and changing the parameterization might screw with their inference. However, we could add it either as an option or as a separate dist.
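For context, a minimal sketch of the centered vs. non-centered distinction being referenced (names, shapes, and the hand-rolled non-centered construction are illustrative, not from this PR):

```python
import pymc as pm

# Centered: the walk itself is the random variable the sampler sees.
with pm.Model():
    grw = pm.GaussianRandomWalk(
        "grw", mu=0, sigma=1, init_dist=pm.Normal.dist(), steps=10
    )

# Non-centered: sample init and innovations, rebuild the walk deterministically.
with pm.Model():
    init = pm.Normal("init", 0, 1)
    innovations = pm.Normal("innovations", 0, 1, shape=10)
    walk = pm.Deterministic(
        "walk", pm.math.concatenate([init[None], init + innovations.cumsum()])
    )
```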
What do you mean? This should be completely equivalent to what we had before.
You might be thinking of the latent parametrizations, but that was not what GRW did before (as that wouldn't have been a vanilla distribution). Something like that would just require implementing a default (diff) transform, which would definitely be helpful for random walks. Alternatively, it would be useful to allow different implementations between unobserved and observed variables: either eagerly at creation time or, even better, lazily when compiling the logp, but the latter requires some more work (make sure we have the right deterministics).
Isn't this changing the previous parameterization:
Where would the parametrization be changed? The random generator method is exactly the same as before, if you check what was going on in the old implementation (pymc/pymc/tests/test_distributions.py, lines 2609 to 2624 in 5703a9d).
That expression is equivalent to what we had before. Now for NUTS sampling (when we have an unobserved GRW) it would probably be better if we were sampling in the latent space. For that we would need to either: 1) add a new type of transform, or 2) have a way to create different graphs depending on whether the variables are observed or not.
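A rough sketch of option 1, assuming the `RVTransform` interface that PyMC v4 transforms build on; the class itself is hypothetical and its registration as a default transform is elided:

```python
import aesara.tensor as at
from aeppl.transforms import RVTransform

class DiffTransform(RVTransform):
    """Hypothetical 'diff' transform: let NUTS sample in innovation space."""

    name = "diff"

    def forward(self, value, *inputs):
        # walk -> (first point, step-wise increments)
        return at.concatenate(
            [value[..., :1], value[..., 1:] - value[..., :-1]], axis=-1
        )

    def backward(self, value, *inputs):
        # (first point, increments) -> walk
        return at.cumsum(value, axis=-1)

    def log_jac_det(self, value, *inputs):
        # cumsum has a triangular Jacobian with unit diagonal, so log|det| = 0
        return at.zeros_like(value[..., 0])
```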
I see, so there is no change from this implementation to the previous in v4, but it is different compared to v3, right?
No, V3 did the same as well. The logp is written a bit differently, but it's equivalent to the manual logp we had here in V4 up to this PR (pymc/pymc3/distributions/timeseries.py, lines 227 to 246 in ed74406).
And there was no special transformation either.
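To make the equivalence concrete, a small self-contained check of the factorized GRW logp (NumPy/SciPy, arbitrary values; a standard-normal init is assumed):

```python
import numpy as np
import scipy.stats as st

x = np.array([0.2, 0.7, 1.5, 1.1])  # one realization of the walk
mu, sigma = 0.0, 1.0                # drift and innovation scale

# logp(x) = logp of x[0] under the init dist
#         + sum over t of Normal(mu, sigma) logp of the increments x[t] - x[t-1]
logp = st.norm(0, 1).logpdf(x[0]) + st.norm(mu, sigma).logpdf(np.diff(x)).sum()
```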
Oh right, I read your previous comment again. It's always been the same parameterization.
cc @ricardoV94 about a question regarding retrieving dist_params from the random graph.
cc @nicospinu for GSoC reference. This PR will end up being a good reference for a side-by-side comparison of how time series work when implemented in PyMC and how they'll be implemented if more integrated with AePPL.
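On retrieving dist_params from the random graph: for a plain Aesara RandomVariable the parameters sit directly in the node's inputs. A minimal illustration (plain Aesara, not this PR's code):

```python
import aesara.tensor as at

x = at.random.normal(2.0, 3.0, size=(5,))
# A RandomVariable node's inputs are (rng, size, dtype, *dist_params),
# so the distribution parameters can be read off the random graph directly.
rng, size, dtype, mu, sigma = x.owner.inputs
```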
To my knowledge, three sets of tests are failing, and they are definitely worth looking into.
Would these issues be important to address? Some of these, especially static shape inference, would take substantial work to be included in Aesara. Happy to hear thoughts about this. Uncommenting the error pertaining to unexpected rv nodes here still yields many errors, so this would be important to look into.
Can you expand?
Yes, that behaves differently, but I don't think it poses a significant problem. We can just change the expectation of the tests.
That should work the same way as before. We do that in AR, which is also a symbolic dist. It happens before any call to rv_op. It does not interfere with or depend on AePPL.
Some tests that checked for shape inference were failing. They seem to have gone away, whether or not this is related to the implemented changes.
In the tests, I summed up the …
I believe that I'm getting closer, minus the implementation of change_size.
Converted this PR back to draft. These lines probably warrant discussion, as the implementation of the …
Some notes from the meeting on Friday, June 10, 2022
Compare: 294f343 to d483ec2
Revisiting this PR after a hiatus. I'm not sure if the presence of random variables in the logp graph is due to retrieving distribution/rv static shapes.
Would this be the appropriate hackish fix? 😅 It seems that the presence of random variables in the logp graph is not an AePPL problem but rather PyMC-related.
Still needs a test for moments.
Actually, I'm realizing that we need a dispatch for the moment here, e.g. for a model like:

```python
with pm.Model() as model:
    mu = pm.Normal("mu", 2, 3)
    sigma = pm.Gamma("sigma", 1, 1)
    grw = pm.GaussianRandomWalk(
        name="grw", mu=mu, sigma=sigma, init_dist=pm.StudentT.dist(5), steps=10
    )
```
We can handle it in the same dispatch moment function, just have an if branch for that. It's the same discussion we had before of whether we specialize or create more generalized moment functions.
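A rough sketch of what such a branch could build on, assuming PyMC's moment dispatcher (pymc.distributions.distribution.moment) and a scalar mu; the signature and names are illustrative, not the PR's actual implementation:

```python
import aesara.tensor as at
from pymc.distributions.distribution import moment

def grw_moment(init_dist, mu, steps):
    # Take the moment of whatever init_dist is (Normal, StudentT, ...) via the
    # existing dispatch, then add the accumulated drift over the steps.
    init_moment = moment(init_dist)
    drift = mu * at.arange(1, steps + 1)
    return at.concatenate([at.atleast_1d(init_moment), init_moment + drift])
```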
I realized that I could just do that. I just added tests for moments. I believe that this PR is ready to be merged, but the moment tests are only passing locally so far; they will pass once PR 151 in AePPL is merged.
With #5955 soon to be merged, I believe that this PR would be good to go? This PR adds a hackish tweak so that we don't check for random variables in the ancestors of a join node.
I don't think we should proceed with the hack at all. It's mixing logic that should be separated. We should override AePPL's join logprob instead, as we have stricter requirements than they do, and do the constant folding of the shapes there.
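A sketch of what that override could look like; `MeasurableJoin`, the `_logprob` dispatcher, and the signature are assumptions mirrored from AePPL's internals of that era, and the body is elided:

```python
from aeppl.logprob import _logprob
from aeppl.tensor import MeasurableJoin

@_logprob.register(MeasurableJoin)
def pymc_join_logprob(op, values, axis, *base_rvs, **kwargs):
    # Enforce PyMC's stricter requirements here, e.g. constant-fold the
    # component shapes before delegating to the per-component logps.
    ...
```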
Oh, my bad, and sounds good. It was recommended by @brandonwillard to edit a graph rewrite or register another one in the database for shape inference.
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: lucianopaz <[email protected]>
```python
        ),
    ],
)
def test_gaussianrandomwalk(mu, sigma, init_dist, steps, size, expected):
```
There is already a test of moments in test_distribution_timeseries.
```python
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init)
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
assert tuple(get_init_dist(grw).shape.eval()) == ()
```
Use tag
```python
def test_shape_ellipsis(self):
    grw = pm.GaussianRandomWalk.dist(
        mu=0, sigma=1, steps=5, init_dist=pm.Normal.dist(), shape=(3, ...)
    )
    assert tuple(grw.shape.eval()) == (3, 6)
    assert tuple(grw.owner.inputs[-2].shape.eval()) == (3,)
    assert tuple(grw.owner.inputs[0].owner.inputs[1].owner.inputs[0].shape.eval()) == (3,)
```
Use tag
```diff
             10,
             (4, 2),
             np.full((4, 2, 11), np.vstack((np.arange(11), -np.arange(11)))),
         ),
     ],
 )
-def test_moment(self, mu, sigma, init_dist, steps, size, expected):
+def test_moment(self, mu, sigma, init, steps, size, expected):
```
No reason to rename init_dist to init.
@larryshamalama do you feel like picking this up after #6072? I think that makes our lives easier here.
Absolutely.
Actually, it seems like #6072 will have to include this one by necessity.
Done in #6072.
Closes #5762.
This PR removes GaussianRandomWalkRV altogether by defining an rv_op as a cumulative sum of distributions. The most recent version of AePPL, i.e. 0.0.31, is now able to retrieve the appropriate logp graph from a CumOp. This is a WIP for now as I figure out how to use AePPL...
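For reference, the core idea in plain Aesara/AePPL terms; the exact joint_logprob call signature is an assumption about AePPL ~0.0.31:

```python
import aesara.tensor as at
from aeppl import joint_logprob

# A GRW is just a cumulative sum of (normal) innovations; the resulting
# CumOp node is something AePPL can now derive a logp graph for.
innovations = at.random.normal(0.0, 1.0, size=(10,))
grw = at.cumsum(innovations)
grw_value = at.vector("grw_value")

# Assumed API: map the measurable output to its value variable.
logp = joint_logprob({grw: grw_value})
```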