From fc4d9f9326139d8ef5253e9f5e4be3efca2e145e Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 27 May 2023 00:37:12 +0200 Subject: [PATCH 1/4] Use observation mask as feature in DeepAR --- src/gluonts/torch/model/deepar/module.py | 49 ++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py index 5f073bc914..fb0a0595a7 100644 --- a/src/gluonts/torch/model/deepar/module.py +++ b/src/gluonts/torch/model/deepar/module.py @@ -146,7 +146,9 @@ def __init__( ) else: self.scaler = NOPScaler(dim=-1, keepdim=True) - self.rnn_input_size = len(self.lags_seq) + self._number_of_features + self.rnn_input_size = ( + 2 * len(self.lags_seq) + ) + self._number_of_features self.rnn = nn.LSTM( input_size=self.rnn_input_size, hidden_size=hidden_size, @@ -216,11 +218,13 @@ def prepare_rnn_input( past_observed_values: torch.Tensor, future_time_feat: torch.Tensor, future_target: Optional[torch.Tensor] = None, + future_observed_values: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]: context = past_target[..., -self.context_length :] observed_context = past_observed_values[..., -self.context_length :] input, _, scale = self.scaler(context, observed_context) + observed_input = observed_context future_length = future_time_feat.shape[-2] if future_length > 1: assert future_target is not None @@ -228,11 +232,24 @@ def prepare_rnn_input( (input, future_target[..., : future_length - 1] / scale), dim=-1, ) + observed_input = torch.cat( + ( + observed_input, + future_observed_values[..., : future_length - 1], + ), + dim=-1, + ) prior_input = past_target[..., : -self.context_length] / scale + observed_prior_input = past_observed_values[ + ..., : -self.context_length + ] lags = lagged_sequence_values( self.lags_seq, prior_input, input, dim=-1 ) + observed_lags = lagged_sequence_values( + self.lags_seq, observed_prior_input, observed_input, dim=-1 + ) time_feat = torch.cat( ( @@ -252,8 +269,8 @@ def prepare_rnn_input( ) features = torch.cat((expanded_static_feat, time_feat), dim=-1) - - return torch.cat((lags, features), dim=-1), scale, static_feat + rnn_input = torch.cat((lags, observed_lags, features), dim=-1) + return (rnn_input, scale, static_feat) def unroll_lagged_rnn( self, @@ -264,6 +281,7 @@ def unroll_lagged_rnn( past_observed_values: torch.Tensor, future_time_feat: torch.Tensor, future_target: Optional[torch.Tensor] = None, + future_observed_values: Optional[torch.Tensor] = None, ) -> Tuple[ Tuple[torch.Tensor, ...], torch.Tensor, @@ -297,6 +315,9 @@ def unroll_lagged_rnn( future_target (Optional) tensor of future target values, shape: ``(batch_size, prediction_length)``. + future_observed_values + (Optional) tensor of future observed values indicators, + shape: ``(batch_size, prediction_length)``. Returns ------- @@ -316,6 +337,7 @@ def unroll_lagged_rnn( past_observed_values, future_time_feat, future_target, + future_observed_values, ) output, new_state = self.rnn(rnn_input) @@ -409,6 +431,9 @@ def forward( past_target.repeat_interleave(repeats=num_parallel_samples, dim=0) / repeated_scale ) + repeated_past_observed_values = past_observed_values.repeat_interleave( + repeats=num_parallel_samples, dim=0 + ) repeated_time_feat = future_time_feat.repeat_interleave( repeats=num_parallel_samples, dim=0 ) @@ -436,13 +461,28 @@ def forward( next_lags = lagged_sequence_values( self.lags_seq, repeated_past_target, scaled_next_sample, dim=-1 ) - rnn_input = torch.cat((next_lags, next_features), dim=-1) + next_observed_lags = lagged_sequence_values( + self.lags_seq, + repeated_past_observed_values, + torch.ones_like(scaled_next_sample), + dim=-1, + ) + rnn_input = torch.cat( + (next_lags, next_observed_lags, next_features), dim=-1 + ) output, repeated_state = self.rnn(rnn_input, repeated_state) repeated_past_target = torch.cat( (repeated_past_target, scaled_next_sample), dim=1 ) + repeated_past_observed_values = torch.cat( + ( + repeated_past_observed_values, + torch.ones_like(scaled_next_sample), + ), + dim=1, + ) params = self.param_proj(output) distr = self.output_distribution(params, scale=repeated_scale) @@ -524,6 +564,7 @@ def loss( past_observed_values, future_time_feat, future_target_reshaped, + future_observed_reshaped, ) if future_only: From 269e60fab05aceda19f044840241b4bb9adec6c7 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 27 May 2023 01:21:46 +0200 Subject: [PATCH 2/4] Update test --- test/torch/model/test_deepar_modules.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/torch/model/test_deepar_modules.py b/test/torch/model/test_deepar_modules.py index fa488d4264..6a4de8b339 100644 --- a/test/torch/model/test_deepar_modules.py +++ b/test/torch/model/test_deepar_modules.py @@ -79,6 +79,7 @@ def test_deepar_modules( past_observed_values, future_time_feat, future_target, + future_observed_values, ) assert scale.shape == (batch_size, 1) @@ -231,6 +232,11 @@ def test_rnn_input( dtype=torch.float32, ).view(1, prediction_length) + batch["future_observed_values"] = torch.ones( + (1, prediction_length), + dtype=torch.float32, + ) + rnn_input, scale, _ = model.prepare_rnn_input(**batch) assert (scale == 1.0).all() From 8f836982edc0629c0ee9ff86adf31ae3331a3e78 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 27 May 2023 01:24:55 +0200 Subject: [PATCH 3/4] Fix mypy issue --- src/gluonts/torch/model/deepar/module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py index fb0a0595a7..3e27c8be02 100644 --- a/src/gluonts/torch/model/deepar/module.py +++ b/src/gluonts/torch/model/deepar/module.py @@ -227,7 +227,10 @@ def prepare_rnn_input( observed_input = observed_context future_length = future_time_feat.shape[-2] if future_length > 1: - assert future_target is not None + assert ( + future_target is not None + and future_observed_values is not None + ) input = torch.cat( (input, future_target[..., : future_length - 1] / scale), dim=-1, From 15204d569b26dd78a1cd0afee9317e378849493c Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Sat, 27 May 2023 10:29:21 +0200 Subject: [PATCH 4/4] Fix MQF tests --- src/gluonts/torch/model/mqf2/lightning_module.py | 2 ++ test/torch/model/test_mqf2_modules.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 6dc824beb4..470eee8d58 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -96,6 +96,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: future_time_feat = batch["future_time_feat"] future_target = batch["future_target"] past_observed_values = batch["past_observed_values"] + future_observed_values = batch["future_observed_values"] picnn = self.model.picnn @@ -107,6 +108,7 @@ def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: past_observed_values, future_time_feat, future_target, + future_observed_values, ) hidden_state = hidden_state[:, : self.model.context_length] diff --git a/test/torch/model/test_mqf2_modules.py b/test/torch/model/test_mqf2_modules.py index 85fa21337f..451a16d890 100644 --- a/test/torch/model/test_mqf2_modules.py +++ b/test/torch/model/test_mqf2_modules.py @@ -78,6 +78,7 @@ def test_mqf2_modules( past_observed_values, future_time_feat, future_target, + future_observed_values, ) hidden_state = hidden_state[:, :context_length]