-
Let me try to understand Jax more carefully. There is a class of PRNGs (splittable PRNGs) that have a number of nice properties.
I'm not exactly sure how the splitting works, but I can ignore those details for the moment. What this essentially means is that the sample space … In Jax, suppose I wanted to create two independent uniforms, I would do:

```python
def X(key):
    k1, k2 = split(key)
    return uniform(k1)

def Y(key):
    k1, k2 = split(key)
    return uniform(k2)
```
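The key-splitting idea can be made concrete with a toy sketch (this is not Jax's actual algorithm; `split` and `uniform` here are hash-based stand-ins I invented for illustration). `X` and `Y` receive the same key and perform the same split, but consume different halves, so both are deterministic in the key yet decorrelated from each other.

```python
import hashlib

# Toy splittable PRNG: a "key" is an integer; split derives two
# new keys deterministically by hashing the old key with a tag.
def split(key):
    def h(tag):
        digest = hashlib.sha256(f"{key}:{tag}".encode()).digest()
        return int.from_bytes(digest[:8], "big")
    return h("left"), h("right")

def uniform(key):
    # Map a key deterministically to a float in [0, 1).
    return (key % 10**9) / 10**9

def X(key):
    k1, k2 = split(key)
    return uniform(k1)   # X consumes the first half of the split

def Y(key):
    k1, k2 = split(key)
    return uniform(k2)   # Y consumes the second half
```

Calling `X(key)` twice gives the same value (determinism), while `X(key)` and `Y(key)` give unrelated values even though they were handed the same key.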
-
After thinking, I have come to a few conclusions. First, to clear up some confusion about what Jax is doing with its keys: essentially, in Jax, the Omega space is the set of 64-bit integers. Splitting takes one key and deterministically derives two new keys from it. What do we actually need? I think there are a few things.
```julia
function f(ω)
  a = (1 ~ Normal(0, 1))(ω)
end
```

This random variable could clash with some other random variable in the same file, another file, or some other person's module. That's bad for composability. Fundamentally, if we have an addressing system, then there needs to be some mechanism by which the addresses I choose do not clash with the ones you choose. One idea that follows from this is to use the addressing scheme that Julia already provides:

```julia
function f(ω)
  is = ids(f)
  a = (is[1] ~ Normal(0, 1))(ω)
end
```

Here `ids(f)` derives addresses unique to `f` itself. Would there be any weird consequences of this? Whether you decide to inline a function or not could change the ids, but I don't think there's any observable effect of this; I can't think of any... A second requirement is that it not be so cumbersome to write. If we were using Cassette, you can imagine tracking the current function in the context, and when someone does …
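To make the clash problem concrete, here is a hypothetical Python sketch (the `sample` helper and the dict-as-ω are invented for illustration, not Omega's API): two functions that both pick the bare address 1 share a value unintentionally, while scoping addresses by the function keeps them distinct.

```python
# ω modeled as a table from addresses to values.
def sample(omega, addr):
    # Lazily fill ω at addr with a placeholder "random" value.
    return omega.setdefault(addr, hash(addr) % 100)

def f(omega):
    return sample(omega, 1)          # bare address 1

def g(omega):
    return sample(omega, 1)          # same bare address -> same value!

# Fix: scope addresses by the enclosing function, like ids(f) above.
def f2(omega):
    return sample(omega, ("f2", 1))  # address scoped to f2

def g2(omega):
    return sample(omega, ("g2", 1))  # distinct scoped address
```

With a fresh `omega`, `f` and `g` alias the same slot, whereas `f2` and `g2` populate two different slots.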
-
Another thought regards projection. This is somewhat inspired by the Jax splittable-rng approach. I think I've been entertaining two separate ideas of a projection which are not quite the same. Example:

```julia
function f(ω)
  a = (1 ~ StdNormal)(ω)
end
```

This currently will be equivalent to … The idea with a projection is that if … Does this buy us anything?

```julia
function f(ω)
  ω1, ω2 = split(ω)
  a = uniform(ω1)
  b = uniform(ω2)
end
```

This projection is what I was doing before, but: …
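A minimal sketch of split-as-projection, assuming ω is a function from ids to uniforms and ids compose by pairing (this mirrors `proj(ω, id) = id2 -> ω(pair(id, id2))` from the semantics below; the hash-based ω is my own stand-in):

```python
import hashlib

def make_omega(seed):
    # ω: maps an id (any printable value) to a uniform in [0, 1).
    def omega(id_):
        digest = hashlib.sha256(f"{seed}:{id_}".encode()).digest()
        return int.from_bytes(digest[:8], "big") / 2**64
    return omega

def proj(omega, id_):
    # proj(ω, id) = id2 -> ω(pair(id, id2))
    return lambda id2: omega((id_, id2))

def split(omega):
    # Splitting is then just projection onto two fixed ids.
    return proj(omega, 1), proj(omega, 2)

omega = make_omega(0)
w1, w2 = split(omega)
a = w1(0)   # a uniform from the first projected subspace
b = w2(0)   # same inner id 0, but a disjoint subspace of ω
```

Because the two projections pair their ids with 1 and 2 respectively, `w1` and `w2` never look up the same slot of ω, which is exactly what the two halves of a key split achieve in Jax.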
-
Another axis is the semantics of classes / plates and so on. I've wavered between two different approaches.

Approach 1

Here, things are simple: the sample space is essentially an infinite sequence of primitives. The main nice thing is that it's very simple, with straightforward semantics. It also allows us to easily model plates. The main downside is about composition. If I want to make a new class, essentially a plate, anything that I want to vary (include within the plate) has to already be a class. In other words, I can't make independent copies of a random variable.

Approach 2

In approach two: … This is very flexible. We'd have to implement it with the projection mechanism. What worries me is that: … So does the …

The problems this brings, however (and it's all coming back to me now), are in things like interventions. What does … So we can't intervene on … The same problem arises in memoization and around proposals.
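The independent copies that Approach 1 cannot express can be sketched with the id-prefixing operator `id1 <| class` from the current semantics described below (a hypothetical Python rendering, not Omega's implementation; the hash-based ω is an assumption):

```python
import hashlib

def make_omega(seed):
    # ω: maps a (possibly nested) id to a uniform in [0, 1).
    def omega(id_):
        digest = hashlib.sha256(f"{seed}:{id_}".encode()).digest()
        return int.from_bytes(digest[:8], "big") / 2**64
    return omega

def std_unif(id_, omega):
    # A primitive class: member id -> uniform sample.
    return omega(("StdUnif", id_))

def prefix(id1, cls):
    # id1 <| class = (id2, ω) -> class(pair(id1, id2), ω)
    return lambda id2, omega: cls((id1, id2), omega)

U_a = prefix("a", std_unif)   # one independent copy of the class
U_b = prefix("b", std_unif)   # another independent copy

omega = make_omega(0)
x = U_a(1, omega)   # member 1 of copy "a"
y = U_b(1, omega)   # member 1 of copy "b": same member id, disjoint slots of ω
```

Prefixing keeps every address a copy touches disjoint from every other copy's addresses, so copies can be made after the fact, without the inner thing having been declared a class up front.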
-
Another question: should the plate ids be the same kind of ids as the projection ids? I wrote this linear regression model:

```julia
xs = rand(N)
plate(i, ω) = linear_model(xs[i], M(ω), C(ω)) + ((i, 1) ~ Normal(0, 0.1))(ω)
Y = (1:N) .~ plate
```

The issue is that … On the other hand, it seems like the indices for a plate should be integers or Cartesian indexes and not tied to the indices we use for projection. Another way to think about it is that for the primitives, we need some indexing system to determine one from another, which is currently … Could we? If we interpret the final index as a mix …

Option 2: `plate(i, ω) = linear_model(xs[i], M(ω), C(ω)) + ((i, 1) ~ Normal(0, 0.1))(ω)`
Option 3: …

Option 2 seems most appealing.
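The plate model above can be sketched in Python, with the per-datapoint noise addressed by `(i, 1)` tuples as in the snippet (a toy: `linear_model`, `M`, and `C` are stand-ins for the latent slope and intercept, and ω is replaced by a deterministic seeded lookup):

```python
import random

N = 5
xs = [i / N for i in range(N)]

def omega_normal(id_, seed=0):
    # Toy ω: a deterministic standard normal looked up at id `id_`.
    return random.Random(f"{seed}:{id_}").gauss(0.0, 1.0)

def normal(id_, mu, sigma):
    return omega_normal(id_) * sigma + mu

# Stand-ins for the latent slope M and intercept C.
M = normal("M", 0.0, 1.0)
C = normal("C", 0.0, 1.0)

def linear_model(x, m, c):
    return m * x + c

def plate(i):
    # Per-datapoint noise is addressed by the pair (i, 1), as in option 2.
    return linear_model(xs[i], M, C) + normal((i, 1), 0.0, 0.1)

Y = [plate(i) for i in range(N)]   # analogue of Y = (1:N) .~ plate
```

Every datapoint shares the same `M` and `C` (one lookup each), while the noise terms get distinct `(i, 1)` addresses, so the `Y` entries are conditionally independent given the latents.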
-
Riffing off #122 and #120 and #6
This is to discuss the nature of independence, conditional independence and related matters.
There have been many iterations of the semantics. Currently, it is as follows:

- There are primitive classes `StdUnif1`, `StdUnif2`, ..., `StdNormal1`, `StdNormal2`, ...
- A class is applied as `class(id, ω)`; a member of the class is `ω -> class(id, ω)`.
- For example, `normal(μ, σ) = (id, ω) -> StdNormal(id, ω) * σ + μ`. If `X = normal(2, 3)` then `X` is a class of normally distributed randvars with mean 2 and stdv 3.
- If also `Y = normal(10, 3)`, it's not the case that `X` and `Y` are independent classes. The independence only follows between elements in `X` (and between elements in `Y`).
- To get an independent class from `Y_ = normal(10, 3)`, pair the ids: `Y(id, ω) = Y_(pair(id, 2), ω)`. More generally, `id1 <| class = (id2, ω) -> class(pair(id1, id2), ω)`, so `Y = 2 <| normal(10, 3)`.
- `ciid` / `~`: `id ~ class = ω -> class(id, ω)`, and `id ~ X = ω -> X(proj(ω, id))` where `proj(ω, id) = id2 -> ω(pair(id, id2))`.
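These definitions can be rendered as a small Python sketch (an illustration of the semantics only, not Omega's implementation; the hash-based ω is my own assumption):

```python
import hashlib
import math

def make_omega(seed):
    # ω: maps a (possibly paired) id to a uniform strictly inside (0, 1).
    def omega(id_):
        d = hashlib.sha256(f"{seed}:{id_}".encode()).digest()
        return (int.from_bytes(d[:8], "big") + 1) / (2**64 + 2)
    return omega

def StdNormal(id_, omega):
    # Primitive class: two paired lookups into ω, Box-Muller transform.
    u1, u2 = omega((id_, "u1")), omega((id_, "u2"))
    return math.sqrt(-2 * math.log(u1)) * math.cos(2 * math.pi * u2)

def normal(mu, sigma):
    # normal(μ, σ) = (id, ω) -> StdNormal(id, ω) * σ + μ
    return lambda id_, omega: StdNormal(id_, omega) * sigma + mu

def tilde(id_, cls):
    # id ~ class = ω -> class(id, ω): pick the id-th member of the class.
    return lambda omega: cls(id_, omega)

X = normal(2, 3)    # a class of Normal(2, 3) random variables
x1 = tilde(1, X)    # two distinct members of the class
x2 = tilde(2, X)

omega = make_omega(0)
```

Applying `x1` and `x2` to the same ω gives two different draws from the same distribution, and averaging many members recovers the class's mean of 2.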
The Jax/Dex approach is a little different:

- There are primitive samplers such as `bernoulli(key, p)`, similar to as described above.
- There is `split_key` to turn a key into two keys.

The difference between Omega's `ω` and a Jax key is that a key can only be split, not indexed into: rather than say you want the 10th element, you'd say you want 10 copies and take the 10th one. I'm not sure if that is a fundamental difference.