### Description

This blog post walks through the logic for 3 different examples: https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/

It shows that the logic is always the same:

1. Wrap the jitted forward pass in an `Op`.
2. Wrap the jitted VJP (vector-Jacobian product, which is what `Op.grad` needs for reverse-mode differentiation) in a gradient `Op` that provides the gradient implementation.
3. Dispatch unjitted versions of the two `Op`s (via `jax_funcify`) so they integrate with `function(..., mode="JAX")`.

Things that cannot be obtained automatically (or maybe they can?) and should be opt-in, as in `@as_op`:

4. Input and output types
5. `infer_shape`