Problem
PyroModule and PyroSample make it straightforward to compositionally specify probabilistic models with random parameters. However, PyroSample has a somewhat awkward interaction with pyro.plate:
To ensure loc and scale are sampled globally, they must be accessed outside the data plate, as scale is in the example above. Inlining self.loc in the final line instead samples a different loc for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models, and it forces the model to carry ugly boilerplate that manually samples the random parameters of submodules in the correct plate context.
For example, in the code below the intuitive behavior for Model.linear is clearly for linear.weight to be sampled outside of the data plate, but because self.linear is invoked for the first time inside the plate, there will be a separate random copy of linear.weight for each plate slice:
However, it would not be correct to simply ignore all plates when executing PyroSample statements. In this example, we might want to use a multi-sample ELBO estimator when inferring self.linear.weight (e.g. pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)), which is implemented with another plate that should not be ignored.
Proposed fix
It would be nice to have a feature that enables the intuitive behavior in the second example above without breaking backwards compatibility with PyroSample's existing semantics, or its correctness in the presence of enclosing plates such as the one introduced by the multi-sample ELBO.
This could potentially be achieved with a new handler PyroSamplePlateScope such that PyroSample statements executed inside its context are only modified by plates entered outside of it, while ordinary pyro.sample statements are unaffected and behave in the usual way:
Personally, I find it more intuitive to treat PyroSample the same as pyro.sample and not inline it. But I don't use PyroSample much, and I can see that inlining it might be convenient if it is used a lot.
As a consideration, should plate scoping be implemented as a context manager, as in PyroSamplePlateScope, or per individual PyroSample (e.g. through infer={"ignored_plates": ...}, which would also work with pyro.sample)? For example, if you want self.loc to be sampled inside of the data plate and self.scale sampled outside of the data plate:
    import pyro
    import pyro.nn

    class Model(pyro.nn.PyroModule):
        @pyro.nn.PyroSample
        def loc(self):
            return pyro.distributions.Normal(0, 1)

        @pyro.nn.PyroSample(infer={"ignored_plates": ["data"]})  # new syntax
        def scale(self):
            return pyro.distributions.LogNormal(0, 1)

        def forward(self, x_obs):
            with pyro.plate("data", x_obs.shape[0], dim=-1):
                # self.loc is local and self.scale is global
                return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)
(One drawback of this approach is that ignored_plates is declared outside the forward method, hidden elsewhere in the class, which can make the code harder to read.)