Jax example not working #44

Open
dkirkby opened this issue Sep 14, 2020 · 7 comments

@dkirkby
Member

dkirkby commented Sep 14, 2020

I am trying to run the jax example via:

python bin/challenge.py example/example_jax.yaml

However, this fails when calculating the score. I put the full traceback below, but I think the problem is in the jax-cosmo chi calculation:

  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/background.py", line 248, in radial_comoving_distance
    return np.clip(interp(a, cache["a"], cache["chi"]), 0.0)

Full traceback (note that the index array is float for some reason):

Executing:  RandomForest riz {'bins': 3, 'colors': True, 'errors': False}
{'bins': 3, 'colors': True, 'errors': False} ['bins']
Initializing classifier...
Training...
Finding bins for training data
Fitting classifier
Applying...
Getting metric...
/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax/lax/lax.py:6193: UserWarning: Explicitly requested dtype <class 'jax.numpy.lax_numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
Traceback (most recent call last):
  File "bin/challenge.py", line 111, in <module>
    main()
  File "/Users/david/.local/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/Users/david/.local/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/Users/david/.local/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/david/.local/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "bin/challenge.py", line 65, in main
    scores = run_one(classifier_name, bands, settings,
  File "bin/challenge.py", line 101, in run_one
    scores = metrics_fn(results, valid_z, metrics=metrics)
  File "/Users/david/Cosmo/DESC/code/zotbin/tomo_challenge/jax_metrics.py", line 227, in compute_scores
    scores['SNR_'+what] = float(compute_snr_score(tomo_bin, z, what=what, binned_nz=True))
  File "/Users/david/Cosmo/DESC/code/zotbin/tomo_challenge/jax_metrics.py", line 125, in compute_snr_score
    return snr_fn(probes, ell)
  File "/Users/david/Cosmo/DESC/code/zotbin/tomo_challenge/jax_metrics.py", line 114, in snr_fn
    mu, C = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes,
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/angular_cl.py", line 193, in gaussian_cl_covariance_and_mean
    cl_signal = angular_cl(
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/angular_cl.py", line 104, in angular_cl
    return cl(ell)
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/angular_cl.py", line 102, in cl
    return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/scipy/integrate.py", line 198, in simps
    y = f(x)
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/angular_cl.py", line 77, in integrand
    chi = bkgrd.radial_comoving_distance(cosmo, a)
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/background.py", line 248, in radial_comoving_distance
    return np.clip(interp(a, cache["a"], cache["chi"]), 0.0)
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/scipy/interpolate.py", line 36, in interp
    a = (fp[ind + np.copysign(1, s)] - fp[ind]) / (
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3648, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx)
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3655, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File "/Users/david/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 3910, in _index_to_gather
    raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
jax.traceback_util.FilteredStackTrace: TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[])>with<BatchTrace(level=1/1)>
  with val = DeviceArray([195., 195., 197., 197., 196., 198., 198., 197., 197., 199.,
                          199., 198., 198., 200., 200., 199., 199., 201., 201., 200.,
                          200., 202., 202., 201., 201., 203., 203., 202., 202., 204.,
                          204., 203., 203., 205., 205., 204., 204., 204., 206., 206.,
                          205., 205., 207., 207., 207., 206., 206., 208., 208., 207.,
                          207., 207., 209., 209., 208., 208., 208., 210., 210., 209.,
                          209., 209., 211., 211., 210., 210., 210., 212., 212., 211.,
                          211., 211., 213., 213., 213., 212., 212., 212., 214., 214.,
                          213., 213., 213., 215., 215., 215., 214., 214., 214., 216.,
                          216., 216., 215., 215., 215., 217., 217., 217., 216., 216.,
                          216., 218., 218., 218., 217., 217., 217., 219., 219., 219.,
                          219., 218., 218., 218., 220., 220., 220., 219., 219., 219.,
                          221., 221., 221., 221., 220., 220., 220., 222., 222., 222.,
                          222., 221., 221., 221., 223., 223., 223., 223., 222., 222.,
                          222., 224., 224., 224., 224., 223., 223., 223., 223., 225.,
                          225., 225., 225., 224., 224., 224., 226., 226., 226., 226.,
                          225., 225., 225., 225., 227., 227., 227., 227., 226., 226.,
                          226., 226., 228., 228., 228., 228., 227., 227., 227., 227.,
                          229., 229., 229., 229., 229., 228., 228., 228., 228., 230.,
                          230., 230., 230., 229., 229., 229., 229., 229., 231., 231.,
                          231., 231., 230., 230., 230., 230., 230., 232., 232., 232.,
                          232., 231., 231., 231., 231., 231., 233., 233., 233., 233.,
                          232., 232., 232., 232., 232., 234., 234., 234., 234., 234.,
                          233., 233., 233., 233., 233., 235., 235., 235., 235., 235.,
                          234., 234., 234., 234., 234., 236., 236., 236., 236., 236.,
                          235., 235., 235., 235., 235., 237., 237., 237., 237., 237.,
                          236., 236., 236., 236., 236., 236., 238., 238., 238., 238.,
                          238., 237., 237., 237., 237., 237., 239., 239., 239., 239.,
                          239., 239., 238., 238., 238., 238., 238., 240., 240., 240.,
                          240., 240., 240., 239., 239., 239., 239., 239., 239., 241.,
                          241., 241., 241., 241., 241., 240., 240., 240., 240., 240.,
                          242., 242., 242., 242., 242., 242., 242., 241., 241., 241.,
                          241., 241., 241., 243., 243., 243., 243., 243., 243., 242.,
                          242., 242., 242., 242., 242., 244., 244., 244., 244., 244.,
                          244., 243., 243., 243., 243., 243., 243., 243., 245., 245.,
                          245., 245., 245., 245., 244., 244., 244., 244., 244., 244.,
                          244., 246., 246., 246., 246., 246., 246., 245., 245., 245.,
                          245., 245., 245., 245., 247., 247., 247., 247., 247., 247.,
                          247., 246., 246., 246., 246., 246., 246., 246., 248., 248.,
                          248., 248., 248., 248., 248., 247., 247., 247., 247., 247.,
                          247., 247., 249., 249., 249., 249., 249., 249., 249., 248.,
                          248., 248., 248., 248., 248., 248., 250., 250., 250., 250.,
                          250., 250., 250., 250., 249., 249., 249., 249., 249., 249.,
                          249., 251., 251., 251., 251., 251., 251., 251., 251., 250.,
                          250., 250., 250., 250., 250., 250., 250., 252., 252., 252.,
                          252., 252., 252., 252., 252., 251., 251., 251., 251., 251.,
                          251., 251., 251., 253., 253., 253., 253., 253., 253., 253.,
                          253., 252., 252., 252., 252., 252., 252., 252., 252., 254.,
                          254., 254., 254., 254., 254., 254., 254., 253., 253., 253.,
                          253., 253., 253., 253., 253., 255., 255., 255., 255., 255.,
                          255., 255., 255., 255., 255., 255., 255., 255., 255., 255.,
                          255., 255., 255.], dtype=float32)
       batch_dim = 0
@dkirkby
Member Author

dkirkby commented Sep 14, 2020

I can reproduce a similar error directly with jax-cosmo:

import jax_cosmo
import jax.numpy as jnp
model = jax_cosmo.parameters.Planck15()
jax_cosmo.background.radial_comoving_distance(model, jnp.linspace(0, 3, 10))

Traceback:

FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-5-6026efa086df> in <module>
----> 1 jax_cosmo.background.radial_comoving_distance(model, jnp.linspace(0, 3, 10))

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/background.py in radial_comoving_distance(cosmo, a, log10_amin, steps)
    247     # Return the results as an interpolation of the table
--> 248     return np.clip(interp(a, cache["a"], cache["chi"]), 0.0)
    249

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax_cosmo/scipy/interpolate.py in interp(x, xp, fp)
     35     s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
---> 36     a = (fp[ind + np.copysign(1, s)] - fp[ind]) / (
     37         xp[ind + np.copysign(1, s)] - xp[ind]

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _rewriting_take(arr, idx)
   3647   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx)
-> 3648   return _gather(arr, treedef, static_idx, dynamic_idx)
   3649

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx)
   3654   idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 3655   indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
   3656   y = arr

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _index_to_gather(x_shape, idx)
   3909                "with type {} at position {}, indexer value {}")
-> 3910         raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))
   3911

FilteredStackTrace: TypeError: Indexer must have integer or boolean type, got indexer with type float32 at position 0, indexer value Traced<ShapedArray(float32[])>with<BatchTrace(level=0/0)>
  with val = DeviceArray([  2., 215., 241., 255., 255., 255., 255., 255., 255., 255.],            dtype=float32)
       batch_dim = 0
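The dtype problem can be isolated without jax-cosmo. The following is a minimal sketch that assumes only `jax.numpy`. The arrays `fp`, `ind` and `s` are hypothetical stand-ins for the table and intermediate values in the `interp` line shown in the traceback. The sketch shows that `copysign` returns a floating-point array even for integer input, so `ind + copysign(1, s)` becomes a float indexer. JAX rejects that indexer with a `TypeError` of the same kind as above. Casting the offset back to an integer type makes the gather valid again.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for the interpolation table and intermediate values
fp = jnp.linspace(0.0, 1.0, 8)            # table values (floating point)
ind = jnp.array([2, 5], dtype=jnp.int32)  # integer positions of table nodes
s = jnp.array([1, -1], dtype=jnp.int32)   # direction to the neighbouring node

# copysign promotes its inputs to floating point, so the offset is a float...
offset = jnp.copysign(1, s)
print(offset.dtype)      # a floating-point dtype
# ...and integer + float gives a float index array
bad_index = ind + offset
print(bad_index.dtype)   # also a floating-point dtype

raised = False
try:
    fp[bad_index]
except TypeError as err:
    # JAX only accepts integer or boolean indexers
    raised = True
    print("TypeError:", err)

# Casting the offset back to an integer keeps the index integer-typed
good_index = ind + offset.astype(jnp.int32)
print(good_index, fp[good_index])
```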

@dkirkby
Member Author

dkirkby commented Sep 14, 2020

I am using the latest (pip install --upgrade) jax and jax-cosmo.
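Exact version numbers make this kind of report easier to act on, because `pip install --upgrade` installs whatever is newest at that moment. The following is a small sketch that uses only the standard library's `importlib.metadata` (Python 3.8+). The helper `installed_version` is hypothetical. The distribution names `jax`, `jaxlib` and `jax-cosmo` are assumed to match how the packages are published on PyPI.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# jaxlib is a separate distribution from jax, so report it as well
for name in ("jax", "jaxlib", "jax-cosmo"):
    print(f"{name}: {installed_version(name)}")
```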

@EiffL
Member

EiffL commented Sep 14, 2020

Arf... this again. This happened in the past when JAX changed the way it handles some indices internally. Are you running the latest JAX version?

@dkirkby
Member Author

dkirkby commented Sep 14, 2020

Yes, the latest available on PyPI, which is 0.1.76.

@EiffL
Member

EiffL commented Sep 14, 2020

OK, I'm updating my JAX version and will try to reproduce that.

@EiffL
Member

EiffL commented Sep 14, 2020

Sorry, it took a while to reinstall everything on my new work machine. The culprit is that the sign function in JAX no longer preserves the int type of its input :-|
I added a quick fix for that to the jax-cosmo master branch:
DifferentiableUniverseInitiative/jax_cosmo@842225f
If you grab the master branch of jax-cosmo, everything should work.
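The commit linked above is the actual fix, and its contents are not reproduced here. The following is an illustration of the general idea only: a minimal sketch of linear interpolation in which every index is integer-typed by construction. `interp_int_safe` is a hypothetical helper, not part of jax-cosmo's API. It also swaps in `jnp.searchsorted` for the sign/copysign neighbour lookup used by the original `interp`. The tracebacks above show a `BatchTrace`, so the sketch is also exercised under `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def interp_int_safe(x, xp, fp):
    """Hypothetical linear interpolation of fp(xp) at x with integer-only indexing.

    Assumes xp is sorted in ascending order and has at least two points.
    Points outside [xp[0], xp[-1]] are extrapolated from the end segments.
    """
    x = jnp.asarray(x)
    # Left node of the segment containing x, clipped so that ind + 1 stays valid
    ind = jnp.searchsorted(xp, x, side="right") - 1
    ind = jnp.clip(ind, 0, xp.shape[0] - 2).astype(jnp.int32)
    # Both gathers use integer indices, so no float indexer can appear
    x0, x1 = xp[ind], xp[ind + 1]
    y0, y1 = fp[ind], fp[ind + 1]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


xp = jnp.linspace(0.0, 1.0, 11)
fp = xp ** 2
x = jnp.array([0.05, 0.5, 0.93])

direct = interp_int_safe(x, xp, fp)
# The same function traced under vmap (a BatchTrace, as in the tracebacks)
batched = jax.vmap(lambda xi: interp_int_safe(xi, xp, fp))(x)
reference = jnp.interp(x, xp, fp)  # built-in interpolation, for comparison
print(direct, batched, reference)
```

Deriving the index from `searchsorted` avoids relying on `sign` or `copysign` to preserve an integer dtype. The thread does not say whether the actual commit took this route or simply added an explicit cast.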

@dkirkby
Member Author

dkirkby commented Sep 14, 2020

Yes, it works now, thanks!
