ENH: Rework Minibatch functionality #7496

Open
@ferrine

Description

Before

Minibatches are tricky to work with, and they rely on a carefully constructed random graph that does not behave as expected in some scenarios. The machinery was designed for the ADVI graph, and when used outside that scope it misbehaves. A couple of immediate issues I know about:

  • Can't be used in leapfrog step iterations, since the random state changes between evaluations (see the sketch after this list)
  • Complicates the internals of PyMC
  • Consumes a lot of memory, since the whole dataset is kept in memory
  • Does not scale well, since it relies on advanced indexing with randomly drawn indices
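
A small illustration of the random-state problem, assuming current pm.Minibatch semantics, where the RNG update is applied on every compiled evaluation so each call draws a fresh slice:

import numpy as np
import pymc as pm

data = np.random.randn(10_000, 4)
mb = pm.Minibatch(data, batch_size=128)

# Every evaluation advances the random state and draws a new slice,
# so two evaluations inside one leapfrog step see different data:
print(mb.eval().mean())
print(mb.eval().mean())  # almost surely a different value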

With all that, as an attempt to improve ADVI it does a very poor job: it is neither scalable nor robust.

After

What can be done differently is to use minibatches in the traditional way, as they are used in e.g. PyTorch: a function produces a new batch that is passed to the loss function. In our case we can use callbacks that are invoked after every ADVI iteration; the callback resets the shared variable's state, making the approach much more scalable and less hacky.

with pm.Model():
    # total_size applies the minibatch scaling to the log-likelihood
    a = pm.Normal("a", total_size=1_000_000, observed=data)
    # proposed callback: loads the next batch from `iterable` into `data`
    minibatch = MinibatchCallback(iterable, [data])
    fit = pm.fit(callbacks=[minibatch])
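
A minimal sketch of what the proposed MinibatchCallback could look like. This is a hypothetical implementation: the (approx, loss, i) call signature matches PyMC's existing fit callbacks such as CheckParametersConvergence, and the iterable is assumed to yield one array per shared variable.

import itertools

class MinibatchCallback:
    """Hypothetical sketch: load the next batch after every ADVI iteration."""

    def __init__(self, iterable, shared_vars, cycle=True):
        # `iterable` is assumed to yield tuples with one array per shared variable
        self._batches = itertools.cycle(iterable) if cycle else iter(iterable)
        self._shared_vars = shared_vars

    def __call__(self, approx, loss, i):
        # pm.fit invokes callbacks as callback(approx, loss, i); we only
        # use the hook to swap the next batch into the shared variables
        batch = next(self._batches)
        for var, value in zip(self._shared_vars, batch):
            var.set_value(value)

Because the shared variables only change when the callback fires, every evaluation within a single optimization step sees the same data, and the iterable can stream batches from disk instead of keeping the full dataset in memory.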

