
Tools for a[e] PPL in Aesara.

brandonwillard/aeppl

AePPL


AePPL provides tools for a[e]PPL written in Aesara.

Build arbitrarily complex probabilistic models. If it is mathematically defined, AePPL will support it.


Features

  • Convert graphs containing Aesara RandomVariables into joint log-probability graphs
  • Transforms for RandomVariables that map constrained support spaces to unconstrained spaces (e.g. the extended real numbers), and a rewrite that automatically applies these transformations throughout a graph
  • Tools for traversing and transforming graphs containing RandomVariables
  • RandomVariable-aware pretty printing and LaTeX output
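The transform feature rests on the change-of-variables formula: if a positive variable X is reparameterized as U = log X, the log-density of U is log p_X(exp(u)) plus the log-Jacobian log|dx/du| = u. Below is a plain-Python sketch of that idea, assuming a log transform applied to the inverse-gamma density used in the example further down. The function names are hypothetical, and this is not AePPL's transform API; it only illustrates why the transformed density remains a valid density on the whole real line.

```python
import math


def invgamma_logpdf(x: float, alpha: float, beta: float) -> float:
    """Log-density of InvGamma(alpha, beta) on its constrained support x > 0."""
    if x <= 0.0:
        return -math.inf
    return (alpha * math.log(beta) - math.lgamma(alpha)
            - (alpha + 1.0) * math.log(x) - beta / x)


def log_transformed_logpdf(u: float, alpha: float, beta: float) -> float:
    """Log-density of U = log(X), defined for every real u (hypothetical helper).

    x = exp(u), so dx/du = exp(u) and the log-Jacobian term is simply u.
    """
    return invgamma_logpdf(math.exp(u), alpha, beta) + u


def integrate(f, lo: float, hi: float, n: int = 100_000) -> float:
    """Midpoint-rule quadrature, used only to sanity-check normalization."""
    h = (hi - lo) / n
    return h * sum(f(lo + (k + 0.5) * h) for k in range(n))


# The transformed density should still integrate to (approximately) one
total = integrate(lambda u: math.exp(log_transformed_logpdf(u, 0.5, 0.5)), -30.0, 60.0)
print(round(total, 6))  # close to 1.0
```

Without the `+ u` Jacobian term the integral would no longer be one, which is why a transform has to adjust the log-probability rather than just substitute values.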

Get started

Using aeppl, one can create a joint log-density graph from a graph containing Aesara RandomVariables:

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint

srng = at.random.RandomStream()

# A simple scale mixture model
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))

# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)

Log-density graphs are standard Aesara graphs, so we can compile them to compute values:

logprob_fn = aesara.function([y, s], logprob)

logprob_fn(-0.5, 1.0)
# array(-2.46287705)
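As a hand check of that number, the joint log-density factors as log p(s) + log p(y | s), with S ~ InvGamma(0.5, 0.5) and Y | S ~ N(0, sqrt(S)), where the second argument of `srng.normal` is the scale. The plain-Python sketch below evaluates those two closed-form terms directly; the helper names are hypothetical and nothing here uses AePPL itself.

```python
import math


def invgamma_logpdf(x: float, alpha: float, beta: float) -> float:
    # log p(x) = alpha*log(beta) - lgamma(alpha) - (alpha + 1)*log(x) - beta/x, for x > 0
    if x <= 0.0:
        return -math.inf
    return alpha * math.log(beta) - math.lgamma(alpha) - (alpha + 1.0) * math.log(x) - beta / x


def normal_logpdf(y: float, mu: float, sigma: float) -> float:
    # log N(y; mu, sigma) = -0.5*log(2*pi) - 0.5*((y - mu)/sigma)**2 - log(sigma)
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma)


def joint_logpdf(y: float, s: float) -> float:
    # log p(y, s) = log p(s) + log p(y | s); outside s > 0 the density is zero
    if s <= 0.0:
        return -math.inf
    return invgamma_logpdf(s, 0.5, 0.5) + normal_logpdf(y, 0.0, math.sqrt(s))


print(joint_logpdf(-0.5, 1.0))  # about -2.462877
```

This agrees with the compiled function to roughly seven decimal places; the remaining difference in the last digits is plausibly due to one of the graph's constants being computed at lower precision.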

AePPL provides utilities to pretty-print the log-density graphs:

from aeppl import pprint, latex_pprint


# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a

print(latex_pprint(Y_rv))
# \begin{equation}
#   \begin{gathered}
#     b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\,  \in \mathbb{R}
#     \\
#     a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\,  \in \mathbb{R}
#   \end{gathered}
#   \\
#   a
# \end{equation}

# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
#         ((-0.9189385175704956 +
#           switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
#         -inf) +
#  ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))
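The two numeric constants in the folded graph are the normalizing terms of the two densities, and they coincide analytically: since lgamma(0.5) = 0.5*log(pi), the inverse-gamma term 0.5*log(0.5) - lgamma(0.5) equals -0.5*log(2*pi), the normal's constant. A small plain-Python check of that identity is below. The tiny gap between the two printed values is consistent with the first having been folded at single precision, but that reading is an assumption rather than something the output states.

```python
import math

# Constants copied from the pretty-printed, constant-folded graph above
invgamma_const = -0.9189385175704956  # from the InvGamma(0.5, 0.5) term
normal_const = -0.9189385332046727    # from the N(0, sqrt(s)) term

# Closed forms of the two normalizing constants
invgamma_exact = 0.5 * math.log(0.5) - math.lgamma(0.5)  # alpha*log(beta) - lgamma(alpha)
normal_exact = -0.5 * math.log(2.0 * math.pi)           # -0.5*log(2*pi)

print(invgamma_exact, normal_exact)
print(abs(invgamma_const - invgamma_exact))  # below 1e-7
```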

Joint log-densities can be computed for some terms that are derived from RandomVariables, as well:

# Create a switching model from a Bernoulli distributed index
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

# Compute the joint log-probability for the mixture
logprob, (m, i) = joint_logprob(M_rv, I_rv)


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
#  ((-0.9189385332046727 + (-0.5 * (((m - [-100  100][a]) / [1. 1.][a]) ** 2))) -
#   log([1. 1.][a])))
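Read term by term, the printed graph is log P(I = i) (which is log 0.5 inside the support {0, 1} and -inf outside it) plus the normal log-density of m under the component selected by i. The plain-Python sketch below spells out that density for the switching model; the constants mirror the example, while the names `MU`, `SIGMA`, `P_I`, and `switching_logpdf` are hypothetical.

```python
import math

MU = [-100.0, 100.0]  # component locations, as in srng.normal([-100, 100], 1.0)
SIGMA = [1.0, 1.0]    # component scales
P_I = 0.5             # Bernoulli probability for the index I


def bernoulli_logpmf(i: int, p: float) -> float:
    # Indices outside {0, 1} get -inf, like the switch in the printed graph
    if i not in (0, 1):
        return -math.inf
    return math.log(p) if i == 1 else math.log1p(-p)


def switching_logpdf(m: float, i: int) -> float:
    # log p(m, i) = log P(I = i) + log N(m; MU[i], SIGMA[i])
    lp_i = bernoulli_logpmf(i, P_I)
    if lp_i == -math.inf:
        return -math.inf
    z = (m - MU[i]) / SIGMA[i]
    return lp_i - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z - math.log(SIGMA[i])


print(switching_logpdf(100.0, 1))  # log(0.5) - 0.5*log(2*pi), about -1.6120857
print(switching_logpdf(100.0, 2))  # -inf: index outside the Bernoulli support
```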

Take a look at the documentation for more examples.

Install

The latest release of aeppl can be installed from PyPI using pip:

pip install aeppl

Or via conda-forge:

conda install -c conda-forge aeppl

The nightly (bleeding edge) version of aeppl can be installed using pip:

pip install aeppl-nightly

Get help

Report bugs by opening an issue. If you have a question about using AePPL, start a discussion or join our Discord server or Gitter chat room.

Contribute

AePPL welcomes contributions. To start contributing, take a look at the open issues.

If you want to implement a new feature, open a discussion or come chat with us on Discord or Gitter.
