toms748_scan doesn't work with JAX backend #2466
Comments
It seems that this can be fixed by explicitly converting
First thing we'll need to do is understand why there aren't test failures for the existing test (lines 26 to 57 in 64ab264), which will probably mean revisiting PR #1274. So writing a failing test would be a good start, so that a PR can make it pass.
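As a starting point, a failing test could look roughly like the sketch below. This is not code from the pyhf test suite: the test name and the direct `pyhf.set_backend` parametrization are placeholders, and the model is just the `uncorrelated_background` example from the docs.

```python
# Sketch of a failing regression test for the toms748_scan + JAX issue.
# Hypothetical test, not part of pyhf's test suite; it reuses the docs'
# uncorrelated_background example and switches backends explicitly.
import pytest

import pyhf
import pyhf.infer.intervals.upper_limits


@pytest.mark.parametrize("backend", ["numpy", "jax"])
def test_toms748_upper_limit_backends(backend):
    pyhf.set_backend(backend)
    model = pyhf.simplemodels.uncorrelated_background(
        signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
    )
    data = pyhf.tensorlib.astensor([51, 48] + model.config.auxdata)
    # scan=None routes through toms748_scan, which is where the JAX
    # backend currently raises the unhashable-array TypeError
    obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.upper_limit(
        data, model, scan=None
    )
    assert float(obs_limit) > 0
```

The `jax` parametrization would keep failing until the underlying hashing issue is fixed, which is exactly what the PR can then make pass.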
Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.
Nope, which likely explains why it wasn't caught. (Adding the
@matthewfeickert I saw the PR, and I think we need to swap the way we're approaching this. Here's my suggestion: instead of type-casting, we need to add in shims across each lib and move some functions into our backend code. See this example:

```python
from functools import lru_cache
import time
import timeit

import jax
import jax.numpy as jnp
import tensorflow as tf


def slow(n):
    # stand-in for an expensive objective: sleep, then return n squared
    time.sleep(1)
    return n**2


# three flavours of "caching": plain memoization vs. JIT compilation per backend
fast = lru_cache(maxsize=None)(slow)
fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)

value = 5
print('slow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast')
print(timeit.timeit(lambda: [fast(value), fast(value), fast(value), fast(value), fast(value)], number=1))

value = jnp.array(5)
print('slow, jax')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, jax')
print(timeit.timeit(lambda: [fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value)], number=1))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, tensorflow')
print(timeit.timeit(lambda: [fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value)], number=1))
```

which outputs
So we can definitely cache those values by JIT-ing for the toms748 scan here, and that's probably what we want to do. My suggestion might be that we support
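To make the shim idea a bit more concrete, here is a rough sketch (every name below is hypothetical, none of it is existing pyhf API): each backend supplies its own caching/compilation decorator, so the scan code never has to hash tensor arguments itself.

```python
# Hypothetical per-backend caching shims -- illustrative only, these names
# do not exist in pyhf.
from functools import lru_cache


def numpy_cache(func):
    # NumPy backend: plain memoization works because the inputs are hashable
    return lru_cache(maxsize=None)(func)


def jax_cache(func):
    # JAX backend: lean on jax.jit's trace cache instead of hashing arrays
    import jax

    return jax.jit(func)


def tensorflow_cache(func):
    # TensorFlow backend: tf.function with XLA compilation plays the same role
    import tensorflow as tf

    return tf.function(jit_compile=True)(func)


# The scan could then pick the right decorator from the active backend
CACHE_SHIMS = {
    "numpy": numpy_cache,
    "jax": jax_cache,
    "tensorflow": tensorflow_cache,
}
```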
Okay, sounds good. Let's start up a separate series of PRs to do this.
Summary
Hello; perhaps this is known but I thought I'd file a bug report just in case. I was testing the `upper_limits` API and discovered that the example given in the documentation doesn't seem to work with the JAX backend. It fails with a complaint about an unhashable array type (see the traceback). If I switch to the numpy backend, as shown in the documentation, it runs fine.

I see this on both EL7 in an ATLAS environment (`StatAnalysis,0.3,latest`) and on my own desktop (Fedora 38); in both cases I have the same pyhf version (0.7.6) and I manually installed `jax[CPU] == 0.4.26` on top of that.

I should add that things work fine with JAX if I use the version of `upper_limits` where I pass in a range of mu values to scan -- so I guess maybe some extra type conversion is needed to go from the JAX array type to a list or something hashable?

OS / Environment
Steps to Reproduce
Install pyhf and JAX through pip; then try to run the example in the documentation, but with the JAX backend instead of numpy:
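(The reporter's exact snippet isn't preserved here; the following is a sketch assuming the `uncorrelated_background` example from the `upper_limits` documentation.)

```python
# Sketch of the documentation example, run with the JAX backend instead of numpy.
import numpy as np  # only needed for the explicit-scan workaround below
import pyhf
import pyhf.infer.intervals.upper_limits

pyhf.set_backend("jax")

model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
data = pyhf.tensorlib.astensor([51, 48] + model.config.auxdata)

# scan=None makes upper_limit fall back to toms748_scan; with the JAX backend
# this is where the unhashable-array TypeError shows up
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.upper_limit(
    data, model, scan=None
)

# Passing an explicit scan range reportedly works fine with JAX:
# obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.upper_limit(
#     data, model, scan=np.linspace(0, 5, 21)
# )
```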
File Upload (optional)
No response
Expected Results
Ideally the example would run without crashing (as it does with the numpy backend).
Actual Results
pyhf Version
0.7.6
Code of Conduct