From 71c900f1ea0b5d3c140e41d67488a889306b3d9d Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Tue, 8 Apr 2025 01:04:29 +0800
Subject: [PATCH 1/4] Use `set_data` in forecast

---
 pymc_extras/statespace/core/statespace.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py
index 77837672..1a756bd6 100644
--- a/pymc_extras/statespace/core/statespace.py
+++ b/pymc_extras/statespace/core/statespace.py
@@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
             provided when the model was built.
         data_dims: str or tuple of str, optional
             Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
+        scenario: dict[str, pd.DataFrame], optional
+            Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
+            in the model. pm.set_data is used to replace training data with new values.
 
     Returns
     -------
@@ -2060,6 +2063,7 @@ def forecast(
 
         with pm.Model(coords=temp_coords) as forecast_model:
             (_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
+                scenario=scenario,
                 data_dims=["data_time", OBS_STATE_DIM],
             )
 
@@ -2073,17 +2077,6 @@ def forecast(
                 "P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
             )
 
-            if scenario is not None:
-                sub_dict = {
-                    forecast_model[data_name]: pt.as_tensor_variable(
-                        scenario.get(data_name), name=data_name
-                    )
-                    for data_name in self.data_names
-                }
-
-                matrices = graph_replace(matrices, replace=sub_dict, strict=True)
-                [setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
-
             _ = LinearGaussianStateSpace(
                 "forecast",
                 x0,

From c47318581b4cd0acd858ad3fa7a683cc93712e0e Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Wed, 9 Apr 2025 09:42:08 -0500
Subject: [PATCH 2/4] Ignore new numpy matmul warnings in tests

---
 tests/statespace/test_SARIMAX.py    | 6 ++++++
 tests/statespace/test_structural.py | 5 +++++
 2 files changed, 11 insertions(+)

diff --git a/tests/statespace/test_SARIMAX.py b/tests/statespace/test_SARIMAX.py
index 40caa9d3..7c9b831e 100644
--- a/tests/statespace/test_SARIMAX.py
+++ b/tests/statespace/test_SARIMAX.py
@@ -252,6 +252,9 @@ def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S):
     "ignore:Non-stationary starting autoregressive parameters found",
     "ignore:Non-invertible starting seasonal moving average",
     "ignore:Non-stationary starting seasonal autoregressive",
+    "ignore:divide by zero encountered in matmul:RuntimeWarning",
+    "ignore:overflow encountered in matmul:RuntimeWarning",
+    "ignore:invalid value encountered in matmul:RuntimeWarning",
 )
 def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
     sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
@@ -361,6 +364,9 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
     "ignore:Non-invertible starting MA parameters found.",
     "ignore:Non-stationary starting autoregressive parameters found",
     "ignore:Maximum Likelihood optimization failed to converge.",
+    "ignore:divide by zero encountered in matmul:RuntimeWarning",
+    "ignore:overflow encountered in matmul:RuntimeWarning",
+    "ignore:invalid value encountered in matmul:RuntimeWarning",
 )
 def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
     if (d + D) > 0:
diff --git a/tests/statespace/test_structural.py b/tests/statespace/test_structural.py
index c398c723..1aae6cd3 100644
--- a/tests/statespace/test_structural.py
+++ b/tests/statespace/test_structural.py
@@ -594,6 +594,11 @@ def test_autoregressive_model(order, rng):
 @pytest.mark.parametrize("s", [10, 25, 50])
 @pytest.mark.parametrize("innovations", [True, False])
 @pytest.mark.parametrize("remove_first_state", [True, False])
+@pytest.mark.filterwarnings(
+    "ignore:divide by zero encountered in matmul:RuntimeWarning",
+    "ignore:overflow encountered in matmul:RuntimeWarning",
+    "ignore:invalid value encountered in matmul:RuntimeWarning",
+)
 def test_time_seasonality(s, innovations, remove_first_state, rng):
     def random_word(rng):
         return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))

From 83bcd9430780612b79fe3dcc00aecb22bb5a403f Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Thu, 24 Apr 2025 15:27:16 -0400
Subject: [PATCH 3/4] Tracking down data bug

---
 pymc_extras/statespace/core/statespace.py |  6 +-
 tests/statespace/test_statespace.py       | 90 +++++++++++++++++++++++
 2 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py
index 1a756bd6..5b9f0d26 100644
--- a/pymc_extras/statespace/core/statespace.py
+++ b/pymc_extras/statespace/core/statespace.py
@@ -1570,8 +1570,10 @@ def _validate_forecast_args(
             raise ValueError(
                 "Integer start must be within the range of the data index used to fit the model."
             )
-        if periods is None and end is None:
-            raise ValueError("Must specify one of either periods or end")
+        if periods is None and end is None and not use_scenario_index:
+            raise ValueError(
+                "Must specify one of either periods or end unless use_scenario_index=True"
+            )
         if periods is not None and end is not None:
             raise ValueError("Must specify exactly one of either periods or end")
         if scenario is None and use_scenario_index:
diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py
index 2b0a1140..6d30b6fb 100644
--- a/tests/statespace/test_statespace.py
+++ b/tests/statespace/test_statespace.py
@@ -870,3 +870,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
     regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
 
     assert_allclose(regression_effect, regression_effect_expected)
+
+
+@pytest.mark.filterwarnings("ignore:Provided data contains missing values.")
+@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
+def test_foreacast_valid_index(rng):
+    # Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
+
+    index = pd.date_range(start="2023-05-01", end="2025-01-29", freq="D")
+    T, k = len(index), 2
+    data = np.zeros((T, k))
+    idx = rng.choice(T, size=10, replace=False)
+    cols = rng.choice(k, size=10, replace=True)
+
+    data[idx, cols] = 1
+
+    df_holidays = pd.DataFrame(data, index=index, columns=["Holiday 1", "Holiday 2"])
+
+    data = rng.normal(size=(T, 1))
+    nan_locs = rng.choice(T, size=10, replace=False)
+    data[nan_locs] = np.nan
+    y = pd.DataFrame(data, index=index, columns=["sales"])
+
+    level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
+    weekly_seasonality = st.TimeSeasonality(
+        season_length=7,
+        state_names=["Sun", "Mon", "Tues", "Wed", "Thu", "Fri", "Sat"],
+        innovations=True,
+        remove_first_state=False,
+    )
+    quarterly_seasonality = st.FrequencySeasonality(season_length=365, n=2, innovations=True)
+    ar1 = st.AutoregressiveComponent(order=1)
+    me = st.MeasurementError()
+
+    exog = st.RegressionComponent(
+        name="exog",  # Name of this exogenous variable component
+        k_exog=2,  # Only one exogenous variable now
+        innovations=False,  # Typically fixed effect (no stochastic evolution)
+        state_names=df_holidays.columns.tolist(),
+    )
+
+    combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
+    ss_mod = combined_model.build()
+
+    with pm.Model(coords=ss_mod.coords) as struct_model:
+        P0_diag = pm.Gamma("P0_diag", alpha=2, beta=10, dims=["state"])
+        P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
+
+        initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
+        # sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"])  # Applied to the level only
+
+        Seasonal_coefs = pm.ZeroSumNormal(
+            "Seasonal[s=7]_coefs", sigma=0.5, dims=["Seasonal[s=7]_state"]
+        )  # DOW dev. from weekly mean
+        sigma_Seasonal = pm.Gamma(
+            "sigma_Seasonal[s=7]", alpha=2, beta=1
+        )  # How much this dev. can dev.
+
+        Frequency_coefs = pm.Normal(
+            "Frequency[s=365, n=2]", mu=0, sigma=0.5, dims=["Frequency[s=365, n=2]_state"]
+        )  # amplitudes in short-term (weekly noise culprit)
+        sigma_Frequency = pm.Gamma(
+            "sigma_Frequency[s=365, n=2]", alpha=2, beta=1
+        )  # smoothness & adaptability over time
+
+        ar_params = pm.Laplace("ar_params", mu=0, b=0.2, dims=["ar_lag"])
+        sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=1)
+
+        sigma_measurement_error = pm.HalfStudentT("sigma_MeasurementError", nu=3, sigma=1)
+
+        data_exog = pm.Data("data_exog", df_holidays.values, dims=["time", "exog_state"])
+        beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
+
+        ss_mod.build_statespace_graph(y, mode="JAX")
+
+        idata = pm.sample_prior_predictive()
+
+    post = ss_mod.sample_conditional_prior(idata)
+
+    # Define start date and forecast period
+    start_date, n_periods = pd.to_datetime("2024-4-15"), 8
+
+    # Extract exogenous data for the forecast period
+    scenario = {
+        "data_exog": pd.DataFrame(
+            df_holidays.loc[start_date:].iloc[:n_periods], columns=df_holidays.columns
+        )
+    }
+
+    # Generate the forecast
+    forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)

From dc02857003ca55692cba59119ae823638e4b07a4 Mon Sep 17 00:00:00 2001
From: jessegrabowski
Date: Thu, 24 Apr 2025 15:27:25 -0400
Subject: [PATCH 4/4] update env file

---
 conda-envs/environment-test.yml | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 450b46e3..cbdd48f5 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -1,17 +1,21 @@
-name: pymc-extras-test
+name: pymc-extras
 channels:
 - conda-forge
 - nodefaults
 dependencies:
-- pymc>=5.21
-- pytest-cov>=2.5
-- pytest>=3.0
+- python=3.12
+- pymc
+- pytensor
+- pytest-cov
+- pytest
 - dask
 - xhistogram
 - statsmodels
-- numba<=0.60.0
+- numba
+- better-optimize
+- jax
+- notebook<7
+- nutpie
 - pip
 - pip:
-  - blackjax
-  - scikit-learn
-  - better_optimize
+  - flowjax
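
Usage sketch (not part of the patches): a minimal example of the scenario-based forecast call that PATCH 1/4 and PATCH 3/4 target, condensed from the new test above. The names `ss_mod`, `idata`, and `df_holidays` are assumed to already exist, built exactly as in that test; this is a sketch under those assumptions, not a definitive API reference.

    import pandas as pd

    # Out-of-sample values for every pm.Data variable in the model, keyed by the
    # name given to pm.Data(...) ("data_exog" in the test above). These values are
    # swapped in for the training data via pm.set_data inside forecast().
    start_date, n_periods = pd.to_datetime("2024-4-15"), 8
    scenario = {"data_exog": df_holidays.loc[start_date:].iloc[:n_periods]}

    # With use_scenario_index=True the forecast horizon is taken from the scenario's
    # DatetimeIndex, which is why _validate_forecast_args (PATCH 3/4) no longer
    # requires `periods` or `end` in that case.
    forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)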