From 5fc01aa1562481c03568b21bcf0a18a79947a099 Mon Sep 17 00:00:00 2001 From: meyers-academic Date: Fri, 27 Feb 2026 17:44:09 +0100 Subject: [PATCH 1/5] add generic function to take an intrinsic red noise spectrum and a common spectrum and have them share a single Fourier basis. This is an extension of `makepowerlaw_crn` to be more generic. Added tests to handle potential corner cases. --- src/discovery/signals.py | 112 ++++++++++++++++++++++++++++++++++ tests/test_signals.py | 128 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 tests/test_signals.py diff --git a/src/discovery/signals.py b/src/discovery/signals.py index da31cc6f..185e9d08 100644 --- a/src/discovery/signals.py +++ b/src/discovery/signals.py @@ -830,6 +830,118 @@ def freespectrum(f, df, log10_rho: typing.Sequence): return jnp.repeat(10.0**(2.0 * log10_rho), 2) +def make_combined_crn(components, irn_psd, crn_psd, crn_prefix='crn_'): + """ + Combine an intrinsic red noise PSD and a common red noise PSD into a + single PSD function that shares the same Fourier basis. + + The intrinsic red noise PSD is evaluated over the full frequency basis, + while the common red noise PSD is added only to the first + ``2 * components`` frequency bins (sine and cosine for each component). + + Parameters + ---------- + components : int + Number of shared Fourier frequency components used by the CRN model. + This determines how many low-frequency bins of the intrinsic basis + receive the CRN contribution (specifically, the first + ``2 * components`` entries, corresponding to sine/cosine pairs). + This is *not* the same as the ``components`` argument passed to + ``makegp_fourier`` — that controls the total number of Fourier + components in the basis for the GP (and may be larger, since the + intrinsic noise can extend to higher frequencies than the CRN). + irn_psd : callable + PSD function for the intrinsic red noise. Must accept ``(f, df, ...)`` + and return a PSD array over the full basis. + crn_psd : callable + PSD function for the common red noise. Must accept ``(f, df, ...)`` + and return a PSD array. Will only be called on the first + ``2 * components`` frequency bins. + crn_prefix : str or None + Prefix applied to CRN parameter names that overlap with IRN names. + For example, if both PSDs have ``log10_A`` and ``crn_prefix='crn_'``, + the combined function will have ``log10_A`` (IRN) and + ``crn_log10_A`` (CRN) as separate parameters. + If None, overlapping names are shared (both PSDs receive the same + value), which is valid when you intentionally want tied parameters. + + Returns + ------- + combined : callable + A PSD function whose signature is the union of ``irn_psd`` and + ``crn_psd`` signatures (with CRN overlaps prefixed). Compatible + with ``makegp_fourier``: argument names are inspectable via + ``getfullargspec``, and ``typing.Sequence`` annotations are + preserved for parameter expansion. + """ + from discovery import matrix + irn_spec = inspect.getfullargspec(irn_psd) + crn_spec = inspect.getfullargspec(crn_psd) + + shared = {'f', 'df'} + irn_names = [a for a in irn_spec.args if a not in shared] + crn_names = [a for a in crn_spec.args if a not in shared] + + # Rename overlapping CRN params + irn_set = set(irn_names) + crn_rename = {} # original_name -> merged_name + for a in crn_names: + if a in irn_set and crn_prefix is not None: + crn_rename[a] = crn_prefix + a + else: + crn_rename[a] = a + + # Build merged argument list: f, df, irn params, then (renamed) crn params + merged_args = ['f', 'df'] + seen = set(shared) + for arg in irn_names: + if arg not in seen: + merged_args.append(arg) + seen.add(arg) + for arg in crn_names: + renamed = crn_rename[arg] + if renamed not in seen: + merged_args.append(renamed) + seen.add(renamed) + + # Merge annotations (applying rename to CRN annotations) + annotations = {} + if irn_spec.annotations: + annotations.update({k: v for k, v in irn_spec.annotations.items() + if k not in shared}) + if crn_spec.annotations: + for k, v in crn_spec.annotations.items(): + if k not in shared: + annotations[crn_rename.get(k, k)] = v + + def _impl(f, df, kw): + irn_kw = {k: kw[k] for k in irn_names} + crn_kw = {k: kw[crn_rename[k]] for k in crn_names} + if matrix.jnp == jnp: + phi = irn_psd(f, df, **irn_kw) + phi = phi.at[:2 * components].add( + crn_psd(f[:2 * components], df[:2 * components], **crn_kw) + ) + else: + phi = irn_psd(f, df, **irn_kw) + phi[:2 * components] += crn_psd( + f[:2 * components], df[:2 * components], **crn_kw + ) + return phi + + # Dynamically build a function with the correct inspectable signature + param_args = merged_args[2:] + args_str = ', '.join(merged_args) + kwargs_dict = '{' + ', '.join(f"'{a}': {a}" for a in param_args) + '}' + func_code = f"def combined({args_str}): return _impl(f, df, {kwargs_dict})" + ns = {'_impl': _impl} + exec(func_code, ns) + combined = ns['combined'] + combined.__annotations__ = annotations + return combined + + + # combined red_noise + crn # this is a factory because it needs to specify a different number of components for the CRN diff --git a/tests/test_signals.py b/tests/test_signals.py new file mode 100644 index 00000000..01a6b32a --- /dev/null +++ b/tests/test_signals.py @@ -0,0 +1,128 @@ +"""Tests for make_combined_crn signature merging and numerical correctness.""" + +import inspect +import numpy as np +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import pytest + +import discovery as ds +from discovery.signals import make_combined_crn + + +# A PSD with non-overlapping parameter names, for testing the no-rename path. +def _alt_psd(f, df, alpha, log10_ref): + return (10.0 ** (2.0 * log10_ref)) * f ** (-alpha) * df + + +def _make_freqs(n_total=30, tspan_years=20): + """Return (f, df) arrays with sin/cos pairs (2*n_total elements).""" + tspan = tspan_years * 365.25 * 86400 + f = jnp.repeat(jnp.arange(1, n_total + 1) / tspan, 2) + df = jnp.ones_like(f) / tspan + return f, df + + +# --------------------------------------------------------------------------- +# Signature tests +# --------------------------------------------------------------------------- + +class TestMakeCombinedCrnSignature: + + def test_same_function_default_prefix(self): + """Overlapping params get crn_ prefix when same function is passed twice.""" + combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw) + args = inspect.getfullargspec(combined).args + assert args == ['f', 'df', 'log10_A', 'gamma', 'crn_log10_A', 'crn_gamma'], \ + f"Got args: {args}" + + def test_same_function_no_prefix_ties_params(self): + """crn_prefix=None with same function: params are tied, no duplication.""" + combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix=None) + args = inspect.getfullargspec(combined).args + assert args == ['f', 'df', 'log10_A', 'gamma'], f"Got args: {args}" + + def test_no_overlap_no_rename(self): + """Non-overlapping param names require no renaming.""" + combined = make_combined_crn(14, ds.powerlaw, _alt_psd) + args = inspect.getfullargspec(combined).args + assert args == ['f', 'df', 'log10_A', 'gamma', 'alpha', 'log10_ref'], \ + f"Got args: {args}" + + def test_custom_prefix(self): + """Custom prefix is applied to overlapping CRN param names.""" + combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix='gw_') + args = inspect.getfullargspec(combined).args + assert args == ['f', 'df', 'log10_A', 'gamma', 'gw_log10_A', 'gw_gamma'], \ + f"Got args: {args}" + + +# --------------------------------------------------------------------------- +# Numerical correctness tests +# --------------------------------------------------------------------------- + +class TestMakeCombinedCrnValues: + + def test_same_function_separate_params(self): + """phi = irn(A1,g1) + crn(A2,g2) on CRN bins; irn(A1,g1) elsewhere.""" + n_crn = 14 + combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) + f, df = _make_freqs() + + log10_A, gamma = -14.5, 4.3 + crn_log10_A, crn_gamma = -15.0, 13 / 3 + + phi = combined(f, df, log10_A, gamma, crn_log10_A, crn_gamma) + irn = ds.powerlaw(f, df, log10_A, gamma) + crn = ds.powerlaw(f[:2 * n_crn], df[:2 * n_crn], crn_log10_A, crn_gamma) + + np.testing.assert_allclose(phi[:2 * n_crn], irn[:2 * n_crn] + crn, rtol=1e-6) + np.testing.assert_allclose(phi[2 * n_crn:], irn[2 * n_crn:], rtol=1e-6) + + def test_same_function_tied_params(self): + """crn_prefix=None + same function: CRN bins = 2 * irn; rest unchanged.""" + n_crn = 14 + combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw, crn_prefix=None) + f, df = _make_freqs() + + log10_A, gamma = -14.5, 4.3 + phi = combined(f, df, log10_A, gamma) + irn = ds.powerlaw(f, df, log10_A, gamma) + + # Both PSDs receive identical params -> CRN contribution doubles the IRN value + np.testing.assert_allclose(phi[:2 * n_crn], 2.0 * irn[:2 * n_crn], rtol=1e-6) + np.testing.assert_allclose(phi[2 * n_crn:], irn[2 * n_crn:], rtol=1e-6) + + def test_no_overlap_values(self): + """Non-overlapping PSDs: CRN bins = irn + alt_psd; rest = irn only.""" + n_crn = 14 + combined = make_combined_crn(n_crn, ds.powerlaw, _alt_psd) + f, df = _make_freqs() + + log10_A, gamma = -14.5, 4.3 + alpha, log10_ref = 3.0, -14.0 + + phi = combined(f, df, log10_A, gamma, alpha, log10_ref) + irn = ds.powerlaw(f, df, log10_A, gamma) + crn = _alt_psd(f[:2 * n_crn], df[:2 * n_crn], alpha, log10_ref) + + np.testing.assert_allclose(phi[:2 * n_crn], irn[:2 * n_crn] + crn, rtol=1e-6) + np.testing.assert_allclose(phi[2 * n_crn:], irn[2 * n_crn:], rtol=1e-6) + + def test_n_crn_boundary(self): + """CRN only affects exactly the first 2*n_crn bins.""" + n_crn = 5 + combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) + f, df = _make_freqs() + + log10_A, gamma = -14.5, 4.3 + crn_log10_A, crn_gamma = -15.0, 13 / 3 + + phi = combined(f, df, log10_A, gamma, crn_log10_A, crn_gamma) + irn = ds.powerlaw(f, df, log10_A, gamma) + + # Bins beyond n_crn are untouched + np.testing.assert_allclose(phi[2 * n_crn:], irn[2 * n_crn:], rtol=1e-6) + # Bins within n_crn are strictly larger than IRN alone + assert np.all(phi[:2 * n_crn] > irn[:2 * n_crn]) From 51fe4f213220198adf99746fc9b08323bd28c611 Mon Sep 17 00:00:00 2001 From: meyers-academic Date: Fri, 27 Feb 2026 18:23:27 +0100 Subject: [PATCH 2/5] update to return common parameter names so we can use them later. --- src/discovery/signals.py | 20 +++++++++++++++++--- tests/test_signals.py | 21 +++++++++++++-------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/discovery/signals.py b/src/discovery/signals.py index 185e9d08..5b1684e8 100644 --- a/src/discovery/signals.py +++ b/src/discovery/signals.py @@ -586,7 +586,6 @@ def invprior(params): # |S_ij Gamma_ab| = prod_i (|S_i Gamma_ab|) = prod_i (S_i^npsr |Gamma_ab|) # log |S_ij Gamma_ab| = log (prod_i S_i^npsr) + log prod_i |Gamma_ab| # = npsr * sum_i log S_i + nfreqs |Gamma_ab| - return (jnp.block([[jnp.make2d(val * invphi) for val in row] for row in invorf]), phi.shape[0] * orflogdet + orfmat.shape[0] * logdetphi) # was -orfmat.shape[0] * jnp.sum(jnp.log(invphidiag))) @@ -830,7 +829,7 @@ def freespectrum(f, df, log10_rho: typing.Sequence): return jnp.repeat(10.0**(2.0 * log10_rho), 2) -def make_combined_crn(components, irn_psd, crn_psd, crn_prefix='crn_'): +def make_combined_crn(components, irn_psd, crn_psd, crn_prefix: typing.Optional[str] = 'crn_'): """ Combine an intrinsic red noise PSD and a common red noise PSD into a single PSD function that shares the same Fourier basis. @@ -873,6 +872,17 @@ def make_combined_crn(components, irn_psd, crn_psd, crn_prefix='crn_'): with ``makegp_fourier``: argument names are inspectable via ``getfullargspec``, and ``typing.Sequence`` annotations are preserved for parameter expansion. + crn_params : list of str + The parameter names (as they appear in ``combined``'s signature) + that belong to the CRN PSD. Pass these directly as the ``common`` + argument to ``makegp_fourier`` or ``makecommongp_fourier`` so that + the CRN parameters are shared across pulsars rather than given + per-pulsar names. + + Example:: + + combined, crn_params = make_combined_crn(14, ds.powerlaw, ds.powerlaw) + gp = makegp_fourier(psr, combined, components=30, common=crn_params) """ from discovery import matrix irn_spec = inspect.getfullargspec(irn_psd) @@ -938,7 +948,11 @@ def _impl(f, df, kw): exec(func_code, ns) combined = ns['combined'] combined.__annotations__ = annotations - return combined + + # Deduplicated list of CRN param names as they appear in the combined signature + crn_params = list(dict.fromkeys(crn_rename[k] for k in crn_names)) + + return combined, crn_params diff --git a/tests/test_signals.py b/tests/test_signals.py index 01a6b32a..5022b9a5 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -32,30 +32,35 @@ class TestMakeCombinedCrnSignature: def test_same_function_default_prefix(self): """Overlapping params get crn_ prefix when same function is passed twice.""" - combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw) + combined, crn_params = make_combined_crn(14, ds.powerlaw, ds.powerlaw) args = inspect.getfullargspec(combined).args assert args == ['f', 'df', 'log10_A', 'gamma', 'crn_log10_A', 'crn_gamma'], \ f"Got args: {args}" + assert crn_params == ['crn_log10_A', 'crn_gamma'], \ + f"Got crn_params: {crn_params}" def test_same_function_no_prefix_ties_params(self): """crn_prefix=None with same function: params are tied, no duplication.""" - combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix=None) + combined, crn_params = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix=None) args = inspect.getfullargspec(combined).args assert args == ['f', 'df', 'log10_A', 'gamma'], f"Got args: {args}" + assert crn_params == ['log10_A', 'gamma'], f"Got crn_params: {crn_params}" def test_no_overlap_no_rename(self): """Non-overlapping param names require no renaming.""" - combined = make_combined_crn(14, ds.powerlaw, _alt_psd) + combined, crn_params = make_combined_crn(14, ds.powerlaw, _alt_psd) args = inspect.getfullargspec(combined).args assert args == ['f', 'df', 'log10_A', 'gamma', 'alpha', 'log10_ref'], \ f"Got args: {args}" + assert crn_params == ['alpha', 'log10_ref'], f"Got crn_params: {crn_params}" def test_custom_prefix(self): """Custom prefix is applied to overlapping CRN param names.""" - combined = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix='gw_') + combined, crn_params = make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix='gw_') args = inspect.getfullargspec(combined).args assert args == ['f', 'df', 'log10_A', 'gamma', 'gw_log10_A', 'gw_gamma'], \ f"Got args: {args}" + assert crn_params == ['gw_log10_A', 'gw_gamma'], f"Got crn_params: {crn_params}" # --------------------------------------------------------------------------- @@ -67,7 +72,7 @@ class TestMakeCombinedCrnValues: def test_same_function_separate_params(self): """phi = irn(A1,g1) + crn(A2,g2) on CRN bins; irn(A1,g1) elsewhere.""" n_crn = 14 - combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) + combined, _ = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) f, df = _make_freqs() log10_A, gamma = -14.5, 4.3 @@ -83,7 +88,7 @@ def test_same_function_separate_params(self): def test_same_function_tied_params(self): """crn_prefix=None + same function: CRN bins = 2 * irn; rest unchanged.""" n_crn = 14 - combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw, crn_prefix=None) + combined, _ = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw, crn_prefix=None) f, df = _make_freqs() log10_A, gamma = -14.5, 4.3 @@ -97,7 +102,7 @@ def test_same_function_tied_params(self): def test_no_overlap_values(self): """Non-overlapping PSDs: CRN bins = irn + alt_psd; rest = irn only.""" n_crn = 14 - combined = make_combined_crn(n_crn, ds.powerlaw, _alt_psd) + combined, _ = make_combined_crn(n_crn, ds.powerlaw, _alt_psd) f, df = _make_freqs() log10_A, gamma = -14.5, 4.3 @@ -113,7 +118,7 @@ def test_no_overlap_values(self): def test_n_crn_boundary(self): """CRN only affects exactly the first 2*n_crn bins.""" n_crn = 5 - combined = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) + combined, _ = make_combined_crn(n_crn, ds.powerlaw, ds.powerlaw) f, df = _make_freqs() log10_A, gamma = -14.5, 4.3 From a42e0181b53b4605308ddb20a1a99861a156f0f5 Mon Sep 17 00:00:00 2001 From: meyers-academic Date: Fri, 27 Feb 2026 19:15:55 +0100 Subject: [PATCH 3/5] update docs --- docs/conf.py | 19 +-- docs/index.rst | 24 ++-- docs/tutorials/curn_example.ipynb | 215 ++++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+), 18 deletions(-) create mode 100644 docs/tutorials/curn_example.ipynb diff --git a/docs/conf.py b/docs/conf.py index 1b39fa48..c7a70a6d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -27,6 +27,7 @@ 'sphinx.ext.intersphinx', 'sphinx_autodoc_typehints', 'numpydoc', + 'myst_nb', ] # Napoleon settings for numpy-style docstrings @@ -49,6 +50,9 @@ numpydoc_show_class_members = False numpydoc_class_members_toctree = False +# MyST-NB settings +nb_execution_mode = 'off' + # Autodoc settings autodoc_typehints = 'description' autodoc_member_order = 'bysource' @@ -68,20 +72,17 @@ # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' +html_theme = 'pydata_sphinx_theme' html_static_path = ['_static'] # Theme options html_theme_options = { - 'logo_only': False, - 'display_version': True, - 'prev_next_buttons_location': 'bottom', - 'style_external_links': False, - 'collapse_navigation': False, - 'sticky_navigation': True, 'navigation_depth': 4, - 'includehidden': True, - 'titles_only': False +} + +# Remove empty left sidebar from all pages +html_sidebars = { + '**': [], } # If you want to add a custom logo, uncomment and specify: diff --git a/docs/index.rst b/docs/index.rst index 093730c2..109c020c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,28 +3,27 @@ Welcome to Discovery's documentation! ====================================== -Discovery is a next-generation pulsar-timing-array data analysis package built on JAX. +Discovery is a next-generation PTA data analysis package built on JAX. .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: User Guide + tutorials/curn_example guide/overview - guide/data_model - guide/pulsar_data installation .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Tutorials quickstart tutorials/basic_likelihood - tutorials/simulations tutorials/optimal_statistic + tutorials/simulations .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Component Reference components/noise_signals @@ -32,13 +31,20 @@ Discovery is a next-generation pulsar-timing-array data analysis package built o components/delays .. toctree:: - :maxdepth: 2 + :maxdepth: 1 + :caption: Other Useful Information + + guide/data_model + guide/pulsar_data + +.. toctree:: + :maxdepth: 1 :caption: Advanced Topics advanced/conditional_sampling .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: API Reference api/index diff --git a/docs/tutorials/curn_example.ipynb b/docs/tutorials/curn_example.ipynb new file mode 100644 index 00000000..b4df0454 --- /dev/null +++ b/docs/tutorials/curn_example.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ce3d00c3-eb23-4686-80fc-090913a839da", + "metadata": {}, + "source": [ + "# How best to run a CURN model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff14aca8-d77b-4ae1-baea-3ab4429efa51", + "metadata": {}, + "outputs": [], + "source": [ + "import discovery as ds\n", + "import jax\n", + "import numpy as np\n", + "import jax.numpy as jnp\n", + "import glob\n", + "from pathlib import Path\n", + "import discovery as ds\n", + "import discovery.samplers.numpyro as ds_numpyro\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7216e4a9-b649-483d-b9ed-d9332b52a418", + "metadata": {}, + "outputs": [], + "source": [ + "datapath = Path(ds.__path__[0]) / '../../data'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43e38887-eba5-4791-9376-2e40ac43ca38", + "metadata": {}, + "outputs": [], + "source": [ + "psrs = [ds.Pulsar.read_feather(f) for f in sorted(datapath.glob('*v1p1*.feather'))][:10] # only 10 pulsars for now" + ] + }, + { + "cell_type": "markdown", + "id": "b1e4c150-d086-43b8-a421-76455936b3c3", + "metadata": {}, + "source": [ + "## Making the model \n", + "{func}`~discovery.signals.make_combined_crn` will combine intrinsic red noise and common nosie\n", + "into a single GP that can use a single Fourier basis.\n", + "since this fixes the names of the common process when it's created, we also return those parameters\n", + "so you can pass them later on.\n", + "\n", + "If you want different Fourier bases (different T-spans) for different pulsars, then\n", + "you will have to do something different." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dc5fd0b-fe92-41b9-b619-209ba7cc06ab", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# common_parnames are the names of parameters\n", + "# that are shared for all pulsars.\n", + "mypl, common_parnames = ds.make_combined_crn(14, ds.powerlaw, ds.powerlaw, crn_prefix='gw_')\n", + "\n", + "psls = [ds.PulsarLikelihood([psr.residuals,\n", + " ds.makenoise_measurement(psr, psr.noisedict),\n", + " ds.makegp_ecorr(psr, psr.noisedict),\n", + " ds.makegp_timing(psr, svd=True)]) for psr in psrs]\n", + "\n", + "commongp = ds.makecommongp_fourier(psrs, mypl, 30, T=ds.getspan(psrs), name='red_noise',\n", + " common=common_parnames)\n", + "\n", + "array_likelihood = ds.ArrayLikelihood(psls, commongp=commongp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "881cc1d5-3fe1-4a05-b415-452e37a99b36", + "metadata": {}, + "outputs": [], + "source": [ + "# array_likelihood.logL.params" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c80e2067-7a65-4283-bb86-1582673755a6", + "metadata": {}, + "outputs": [], + "source": [ + "test_params = ds.sample_uniform(array_likelihood.logL.params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95ab3572-45ef-4391-ae5c-7b1d6dcc7997", + "metadata": {}, + "outputs": [], + "source": [ + "jlogl = jax.jit(array_likelihood.logL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73c26aa8-11f4-49be-8f4d-0b55a9d44945", + "metadata": {}, + "outputs": [], + "source": [ + "jlogl(test_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e73229c5-b6fa-4a2a-9916-6399f6d014f9", + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "jlogl(test_params)" + ] + }, + { + "cell_type": "markdown", + "id": "e212fdca-2048-47e2-9c9e-ae7b8de33bfa", + "metadata": {}, + "source": [ + "## Variable transformations and performance\n", + "This does a transformation so that the parameters that get sampled\n", + "live on the full real line instead of uniform in a fixed range\n", + "this helps with NUTS sampling\n", + "\n", + "ATTENTION!!! In creating this transformed likelihood, JAX\n", + "actually bypasses the parameter dictionary completely\n", + "when it is compiled. This seems to give a large performance benefit on GPUs, where the dictionary\n", + "rolling and unrolling seems to cause significant overhead.\n", + "\n", + "For both the sampling reason, and this performance reason, I'd recommend using these transformations if possible. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5cff402-90f7-478a-a3bd-6da2ec764aae", + "metadata": {}, + "outputs": [], + "source": [ + "npmodel = ds_numpyro.makemodel_transformed(jlogl)\n", + "npsampler = ds_numpyro.makesampler_nuts(npmodel,\n", + " num_warmup=100,\n", + " num_samples=100)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b7ca6d5-34b6-47b0-b65a-f14296bd20c2", + "metadata": {}, + "outputs": [], + "source": [ + "npsampler.run(jax.random.key(0))\n", + "chain = npsampler.to_df()\n", + "# chain.to_csv('chain.feather', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52cdf3c4-bf1b-4e9e-b068-e8687249bcb2", + "metadata": {}, + "outputs": [], + "source": [ + "plt.hist(chain['gw_gamma'])\n", + "plt.xlabel(\"$\\gamma_{gw}$\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From cfda713754d5505278893522482ed080b9331076 Mon Sep 17 00:00:00 2001 From: meyers-academic Date: Fri, 27 Feb 2026 19:16:15 +0100 Subject: [PATCH 4/5] update docs. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 685dfbcf..1530e076 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,9 +69,10 @@ flow = [ docs = [ "sphinx>=7.0", - "sphinx-rtd-theme>=2.0", + "pydata-sphinx-theme>=0.15", "sphinx-autodoc-typehints>=1.25", "numpydoc>=1.6", + "myst-nb>=1.0", ] [project.urls] From 6419d8af1f354d5b472e7857a3a8d4f0d2b26532 Mon Sep 17 00:00:00 2001 From: meyers-academic Date: Fri, 27 Feb 2026 19:18:12 +0100 Subject: [PATCH 5/5] add docs to readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index fc58a22c..c84f1aed 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # Discovery [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17711453.svg)](https://doi.org/10.5281/zenodo.17711453) +[Sphinx Documentation is here](https://nanograv.github.io/discovery/). The old README user guide is below as well. + Logo _Discovery_ is a next-generation pulsar-timing-array data-analysis package, _built for speed_ on a [JAX](https://jax.readthedocs.io/en/latest/) backend that supports GPU execution and autodifferentiation.