Description
Before
Minibatches are tricky to work with, and they rely on a carefully constructed random graph that does not behave as expected in some scenarios. It was designed to work inside the ADVI graph, and it misbehaves when used outside that scope. A couple of immediate issues I know about:
- Can't be used in leapfrog step iterations, since the random state changes between evaluations
- Complicates the internals of pymc
- Consumes a lot of memory, since the whole dataset is stored in memory
- Does not scale well, since it relies on advanced indexing with a random draw of indices
With all that, it does a very poor job even at improving ADVI, which it was built for: it is not scalable and it is fragile. For reference, the current usage looks roughly like the sketch below.
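A minimal sketch of the current usage, for reference. The data array and sizes are made up for illustration; `pm.Minibatch` keeps the full array in memory and draws a new random slice (via advanced indexing) every time the graph is evaluated:

```python
import numpy as np
import pymc as pm

# The full dataset has to live in memory; Minibatch wraps it in a graph
# node that takes a fresh random slice on every evaluation.
full_data = np.random.randn(1_000_000)

with pm.Model():
    batch = pm.Minibatch(full_data, batch_size=128)        # random advanced-indexing node
    pm.Normal("a", observed=batch, total_size=1_000_000)   # rescale logp to the full dataset
    approx = pm.fit()                                       # ADVI; re-draws a slice each iteration
```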
After
What can be done differently is to use minibatches in the traditional way, as they are used in e.g. PyTorch: a function produces a new batch and passes it to the loss function. In our case we can use callbacks that are called after every ADVI iteration; such a callback would reset the state of a shared variable holding the current batch, making the approach much more scalable and less hacky. A rough sketch of the proposed API:
```python
with pm.Model():
    a = pm.Normal("a", total_size=1000000, observed=data)  # apply scaling to the full dataset size
    minibatch = MinibatchCallback(iterable, [data])
    fit = pm.fit(callbacks=[minibatch])
```
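A minimal sketch of what such a `MinibatchCallback` could look like. It does not exist yet; the only existing pieces assumed here are the `callbacks=` hook of `pm.fit` (each callback is invoked as `callback(approx, losses, i)` after an iteration) and shared variables (`pytensor.shared`, or `aesara`/`theano` depending on the PyMC version) holding the observed data. All other names and sizes are illustrative:

```python
import itertools
import numpy as np
import pymc as pm
import pytensor


class MinibatchCallback:
    """Hypothetical callback that pushes the next batch into shared variables."""

    def __init__(self, iterable, shared_vars):
        self._iterator = iter(iterable)
        self._shared_vars = shared_vars

    def __call__(self, approx, losses, i):
        # Called by pm.fit after every ADVI iteration: advance the batch
        # iterator and overwrite the shared variables with the new values.
        batch = next(self._iterator)
        if not isinstance(batch, (tuple, list)):
            batch = (batch,)
        for var, value in zip(self._shared_vars, batch):
            var.set_value(np.asarray(value))


# Illustrative usage: stream random batches of a large in-memory array
# (the iterable could just as well read from disk or a database).
full_data = np.random.randn(1_000_000)
batches = (
    full_data[np.random.randint(0, len(full_data), size=128)] for _ in itertools.count()
)
data = pytensor.shared(full_data[:128], name="data")

with pm.Model():
    pm.Normal("a", observed=data, total_size=1_000_000)  # rescale logp to the full dataset
    fit = pm.fit(callbacks=[MinibatchCallback(batches, [data])])
```

This keeps the batching logic outside of the computational graph: the model only ever sees a fixed-shape shared variable, so no random state leaks into sampling or leapfrog steps, and the full dataset never has to be embedded in the graph.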
Context for the issue:
No response