
Use set_data in forecast method #451

Open
wants to merge 4 commits into
base: main
20 changes: 12 additions & 8 deletions conda-envs/environment-test.yml
@@ -1,17 +1,21 @@
name: pymc-extras-test
name: pymc-extras
channels:
- conda-forge
- nodefaults
dependencies:
- pymc>=5.21
- pytest-cov>=2.5
- pytest>=3.0
- python=3.12
- pymc
- pytensor
- pytest-cov
- pytest
- dask
- xhistogram
- statsmodels
- numba<=0.60.0
- numba
- better-optimize
- jax
- notebook<7
- nutpie
- pip
- pip:
- blackjax
- scikit-learn
- better_optimize
- flowjax
21 changes: 8 additions & 13 deletions pymc_extras/statespace/core/statespace.py
@@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
provided when the model was built.
data_dims: str or tuple of str, optional
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
+ scenario: dict[str, pd.DataFrame], optional
+ Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
+ in the model. pm.set_data is used to replace training data with new values.

Returns
-------
@@ -1567,8 +1570,10 @@ def _validate_forecast_args(
raise ValueError(
"Integer start must be within the range of the data index used to fit the model."
)
- if periods is None and end is None:
- raise ValueError("Must specify one of either periods or end")
+ if periods is None and end is None and not use_scenario_index:
+ raise ValueError(
+ "Must specify one of either periods or end unless use_scenario_index=True"
+ )
if periods is not None and end is not None:
raise ValueError("Must specify exactly one of either periods or end")
if scenario is None and use_scenario_index:
@@ -2060,6 +2065,7 @@ def forecast(

with pm.Model(coords=temp_coords) as forecast_model:
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
+ scenario=scenario,

Hey @jessegrabowski, sorry about the constant back and forth, but I think this addition fixes the bug. The tests still fail, but the error message is no longer the expected dimension mismatch; it is an assertion error:

FAILED tests/statespace/test_statespace.py::test_foreacast_valid_index - UserWarning: Skipping `CheckAndRaise` Op (assertion: The first dimension of a time varying matrix (the time dimension) must be equal to the first dimension of the data (the time dimension).) as JAX tracing would remove it.

Running the test test_foreacast_valid_index in a notebook works without any problems.

Member Author

It's just a warning, so we need to add a @pytest.mark.filterwarnings decorator to the test.
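As a rough illustration of that decorator (a sketch only, not code from this PR): the filter string below reuses the message prefix from the failure output quoted above, while `hypothetical_jax_forecast` and the test name are made-up stand-ins for the real forecast call and test.

```python
import warnings

import pytest

# Message prefix copied from the failure output quoted in this thread.
JAX_ASSERT_WARNING = "Skipping `CheckAndRaise` Op"


def hypothetical_jax_forecast():
    # Stand-in for the real forecast call: it only emits the warning seen in CI.
    warnings.warn(
        f"{JAX_ASSERT_WARNING} (assertion: ...) as JAX tracing would remove it.",
        UserWarning,
    )
    return "forecast"


# Filter spec format is "action:message:category"; the message part is a regex
# matched against the start of the warning text.
@pytest.mark.filterwarnings(f"ignore:{JAX_ASSERT_WARNING}:UserWarning")
def test_forecast_ignores_jax_assert_warning():
    assert hypothetical_jax_forecast() == "forecast"
```

When the suite is configured to turn warnings into errors, the mark silences only this message for this one test; other warnings would still fail it.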

I made this change then undid it because I thought it wasn't doing the right thing. I need to double-check, but the forecasting logic basically goes like this:

  1. The original user graph -- using the training data -- is reconstructed in its entirety.
  2. Using the original graph, we compute the value of the hidden states at the requested x0 for the forecasts. We do not want new scenario data here, otherwise the value of the requested x0 will be wrong.
  3. Starting from this x0, we construct a new graph which iterates the statespace equations forward for the requested number of time steps. If there are exogenous regressors, this is where we do want them.

I think I was under the impression that this change put the scenario data into step (2), but I looked at it a few weeks ago and no longer remember.
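The three steps can be sketched with a toy model. This is not the pymc-extras implementation: it assumes a scalar AR(1) hidden state with a single exogenous regressor in the observation equation, and every name in it (`filter_training_data`, `forecast_from`, the parameter values) is hypothetical. The point it illustrates is that the training regressor is used only while filtering to obtain x0, and the scenario regressor is used only while iterating forward.

```python
def filter_training_data(y, u, phi, beta, q, r, x_init=0.0, p_init=1.0):
    """Steps 1-2: Kalman-filter the training data (with the training regressor)."""
    means, variances = [], []
    x, p = x_init, p_init
    for y_t, u_t in zip(y, u):
        # Predict the hidden state one step ahead.
        x_pred = phi * x
        p_pred = phi * phi * p + q
        # Update with the observation; the regression effect is part of the mean.
        resid = y_t - (x_pred + beta * u_t)
        gain = p_pred / (p_pred + r)
        x = x_pred + gain * resid
        p = (1.0 - gain) * p_pred
        means.append(x)
        variances.append(p)
    return means, variances


def forecast_from(x0, p0, scenario_u, phi, beta, q, r):
    """Step 3: iterate the state equations forward; scenario data enters only here."""
    out = []
    x, p = x0, p0
    for u_t in scenario_u:
        x = phi * x
        p = phi * phi * p + q
        # Forecast mean and variance of the observation at this step.
        out.append((x + beta * u_t, p + r))
    return out


phi, beta, q, r = 0.9, 2.0, 0.1, 0.1
train_y = [1.0, 1.2, 0.9, 3.1, 1.0]
train_u = [0.0, 0.0, 0.0, 1.0, 0.0]  # training regressor, used in steps 1-2 only

means, variances = filter_training_data(train_y, train_u, phi, beta, q, r)
t0 = len(train_y) - 1  # forecast from the last training observation

scenario_u = [0.0, 1.0, 0.0]  # out-of-sample regressor values
forecasts = forecast_from(means[t0], variances[t0], scenario_u, phi, beta, q, r)
baseline = forecast_from(means[t0], variances[t0], [0.0, 0.0, 0.0], phi, beta, q, r)
```

Because x0 comes from the filter run on training data, changing `scenario_u` shifts the forecast means by `beta * u_t` without altering the starting state, which is the behaviour described in step (2).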

data_dims=["data_time", OBS_STATE_DIM],
)

@@ -2073,17 +2079,6 @@ def forecast(
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
)

- if scenario is not None:
- sub_dict = {
- forecast_model[data_name]: pt.as_tensor_variable(
- scenario.get(data_name), name=data_name
- )
- for data_name in self.data_names
- }
-
- matrices = graph_replace(matrices, replace=sub_dict, strict=True)
- [setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
-
_ = LinearGaussianStateSpace(
"forecast",
x0,
6 changes: 6 additions & 0 deletions tests/statespace/test_SARIMAX.py
@@ -252,6 +252,9 @@ def test_make_SARIMA_transition_matrix(p, d, q, P, D, Q, S):
"ignore:Non-stationary starting autoregressive parameters found",
"ignore:Non-invertible starting seasonal moving average",
"ignore:Non-stationary starting seasonal autoregressive",
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_SARIMAX_update_matches_statsmodels(p, d, q, P, D, Q, S, data, rng):
sm_sarimax = sm.tsa.SARIMAX(data, order=(p, d, q), seasonal_order=(P, D, Q, S))
Expand Down Expand Up @@ -361,6 +364,9 @@ def test_interpretable_states_are_interpretable(arima_mod_interp, pymc_mod_inter
"ignore:Non-invertible starting MA parameters found.",
"ignore:Non-stationary starting autoregressive parameters found",
"ignore:Maximum Likelihood optimization failed to converge.",
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_representations_are_equivalent(p, d, q, P, D, Q, S, data, rng):
if (d + D) > 0:
90 changes: 90 additions & 0 deletions tests/statespace/test_statespace.py
@@ -870,3 +870,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])

assert_allclose(regression_effect, regression_effect_expected)


@pytest.mark.filterwarnings("ignore:Provided data contains missing values.")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
def test_foreacast_valid_index(rng):
# Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424

index = pd.date_range(start="2023-05-01", end="2025-01-29", freq="D")
T, k = len(index), 2
data = np.zeros((T, k))
idx = rng.choice(T, size=10, replace=False)
cols = rng.choice(k, size=10, replace=True)

data[idx, cols] = 1

df_holidays = pd.DataFrame(data, index=index, columns=["Holiday 1", "Holiday 2"])

data = rng.normal(size=(T, 1))
nan_locs = rng.choice(T, size=10, replace=False)
data[nan_locs] = np.nan
y = pd.DataFrame(data, index=index, columns=["sales"])

level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
weekly_seasonality = st.TimeSeasonality(
season_length=7,
state_names=["Sun", "Mon", "Tues", "Wed", "Thu", "Fri", "Sat"],
innovations=True,
remove_first_state=False,
)
quarterly_seasonality = st.FrequencySeasonality(season_length=365, n=2, innovations=True)
ar1 = st.AutoregressiveComponent(order=1)
me = st.MeasurementError()

exog = st.RegressionComponent(
name="exog", # Name of this exogenous variable component
k_exog=2, # Two exogenous variables (the two holiday columns)
innovations=False, # Typically fixed effect (no stochastic evolution)
state_names=df_holidays.columns.tolist(),
)

combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
ss_mod = combined_model.build()

with pm.Model(coords=ss_mod.coords) as struct_model:
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=10, dims=["state"])
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])

initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
# sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"]) # Applied to the level only

Seasonal_coefs = pm.ZeroSumNormal(
"Seasonal[s=7]_coefs", sigma=0.5, dims=["Seasonal[s=7]_state"]
) # DOW dev. from weekly mean
sigma_Seasonal = pm.Gamma(
"sigma_Seasonal[s=7]", alpha=2, beta=1
) # How much this dev. can dev.

Frequency_coefs = pm.Normal(
"Frequency[s=365, n=2]", mu=0, sigma=0.5, dims=["Frequency[s=365, n=2]_state"]
) # amplitudes in short-term (weekly noise culprit)
sigma_Frequency = pm.Gamma(
"sigma_Frequency[s=365, n=2]", alpha=2, beta=1
) # smoothness & adaptability over time

ar_params = pm.Laplace("ar_params", mu=0, b=0.2, dims=["ar_lag"])
sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=1)

sigma_measurement_error = pm.HalfStudentT("sigma_MeasurementError", nu=3, sigma=1)

data_exog = pm.Data("data_exog", df_holidays.values, dims=["time", "exog_state"])
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

ss_mod.build_statespace_graph(y, mode="JAX")

idata = pm.sample_prior_predictive()

post = ss_mod.sample_conditional_prior(idata)

# Define start date and forecast period
start_date, n_periods = pd.to_datetime("2024-4-15"), 8

# Extract exogenous data for the forecast period
scenario = {
"data_exog": pd.DataFrame(
df_holidays.loc[start_date:].iloc[:n_periods], columns=df_holidays.columns
)
}

# Generate the forecast
forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)
5 changes: 5 additions & 0 deletions tests/statespace/test_structural.py
@@ -594,6 +594,11 @@ def test_autoregressive_model(order, rng):
@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
@pytest.mark.parametrize("remove_first_state", [True, False])
@pytest.mark.filterwarnings(
"ignore:divide by zero encountered in matmul:RuntimeWarning",
"ignore:overflow encountered in matmul:RuntimeWarning",
"ignore:invalid value encountered in matmul:RuntimeWarning",
)
def test_time_seasonality(s, innovations, remove_first_state, rng):
def random_word(rng):
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))