### Bug Description

This is one of the issues reported in https://github.com/pyro-ppl/numpyro/issues/1981. Running the following test with JAX's tracer-leak checker enabled raises an error (or is reported as xfail).

### Steps to Reproduce

```
JAX_CHECK_TRACER_LEAKS=1 pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
```

### Expected Behavior

The test should pass.
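For context, `JAX_CHECK_TRACER_LEAKS=1` turns on the same check as `jax.checking_leaks()` or `jax.config.update("jax_check_tracer_leaks", True)`. It makes JAX raise when a tracer outlives the trace that created it. The snippet below is a minimal, hypothetical illustration of that kind of leak and is not taken from `test_scan_hmm_smoke`. The names `leaky` and `leak_detected` are made up for this example, and it assumes a recent JAX release whose leak error message contains "Leaked".

```python
import jax
import jax.numpy as jnp

leaked = []  # outer-scope list: storing a tracer here lets it escape the trace


@jax.jit
def leaky(x):
    leaked.append(x)  # x is a tracer while jit traces; keeping a reference leaks it
    return x * 2


def leak_detected() -> bool:
    """Return True if JAX's leak checker rejects the leaky function."""
    # Scoped equivalent of running with JAX_CHECK_TRACER_LEAKS=1.
    with jax.checking_leaks():
        try:
            leaky(jnp.ones(3))
        except Exception as err:  # JAX raises a plain Exception mentioning "Leaked"
            return "Leaked" in str(err)
    return False


if __name__ == "__main__":
    print(leak_detected())
```

A leak reported in `test_scan_hmm_smoke` would point to a tracer created inside `scan` being kept alive beyond its trace in a similar way.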