Skip to content

Commit

Permalink
add poisson.rvs (#94)
Browse files Browse the repository at this point in the history
Also makes norm.rvs faster and simpler
  • Loading branch information
HDembinski authored Feb 9, 2024
1 parent 3675374 commit 51b6282
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 10 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.11"]
python-version: ["3.8", "3.12"]

steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: python -m pip install --upgrade pip
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ jobs:
upload:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: '3.9'

Expand All @@ -22,7 +22,7 @@ jobs:
- run: python -m pip install --force-reinstall dist/*.tar.gz
- run: python -m pytest

- uses: pypa/gh-action-pypi-publish@master
- uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{secrets.PYPI_TOKEN}}
5 changes: 2 additions & 3 deletions src/numba_stats/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ def _ppf(p, loc, scale):
return r


@_rvs_jit(2, cache=False)
@_rvs_jit(2)
def _rvs(loc, scale, size, random_state):
_seed(random_state)
p = np.random.uniform(0, 1, size)
return _ppf(p, loc, scale)
return np.random.normal(loc, scale, size)


_generate_wrappers(globals())
14 changes: 13 additions & 1 deletion src/numba_stats/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy as np
from ._special import gammaincc as _gammaincc
from math import lgamma as _lgamma
from ._util import _jit, _generate_wrappers, _prange
from ._util import _jit, _generate_wrappers, _prange, _seed
import numba as nb

_doc_par = """
mu : float
Expand Down Expand Up @@ -43,4 +44,15 @@ def _cdf(k, mu):
return r


@nb.njit(
nb.int64[:](nb.float32, nb.uint64, nb.optional(nb.uint64)),
cache=True,
inline="always",
error_model="numpy",
)
def _rvs(mu, size, random_state):
_seed(random_state)
return np.random.poisson(mu, size)


_generate_wrappers(globals())
13 changes: 13 additions & 0 deletions tests/test_poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numba_stats import poisson
import scipy.stats as sc
import pytest
import numba as nb


@pytest.mark.parametrize("mu", np.linspace(0, 3, 5))
Expand All @@ -18,3 +19,15 @@ def test_cdf(mu):
got = poisson.cdf(k, mu)
expected = sc.poisson.cdf(k, mu)
np.testing.assert_allclose(got, expected)


@pytest.mark.parametrize("mu", np.linspace(0, 3, 5))
def test_rvs(mu):
got = poisson.rvs(mu, size=1000, random_state=1)

@nb.njit
def expected():
np.random.seed(1)
return np.random.poisson(mu, 1000)

np.testing.assert_equal(got, expected())

0 comments on commit 51b6282

Please sign in to comment.