Skip to content

Commit 95f644b

Browse files
committedJul 18, 2024
add new channel support
1 parent fdd9af3 commit 95f644b

File tree

3 files changed

+45
-8
lines changed

3 files changed

+45
-8
lines changed
 

‎Pretrained-RWKV_TS/src/dataset.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,21 @@ def build_test_set(self):
2525
data_df = df[keys_sorted[-1]]
2626
X = data_df["nwp_ws100"].to_numpy()[:, np.newaxis]
2727
y = data_df["fj_windSpeed"].to_numpy()[:, np.newaxis]
28+
print(f"input shape: {X.shape}, target shape: {y.shape}")
29+
# shift the fj_windSpeed
30+
shifted_list = []
31+
for i in range(1, self.args.shift_steps+1):
32+
shifted_windspeed = data_df["fj_windSpeed"].shift(i).fillna(0).to_numpy()[:, np.newaxis]
33+
shifted_list.append(shifted_windspeed)
34+
self.shifted_y = np.concatenate(shifted_list, axis=1)
2835
# split X, y to chunk
2936
X_chunks = [X[i-self.prefix_len:i+self.seq_len] for i in range(0, len(X), self.seq_len) if i != 0]
3037
y_chunks = [y[i:i+self.seq_len] for i in range(0, len(y), self.seq_len) if i != 0]
38+
y_shifted_chunks = [self.shifted_y[i:i+self.seq_len] for i in range(0, len(self.shifted_y), self.seq_len) if i != 0]
3139
print(f"input chunks: {len(X_chunks)}, target chunks: {len(y_chunks)}")
3240
self.X = X_chunks
3341
self.y = y_chunks
42+
self.shifted_y = y_shifted_chunks
3443

3544
def __len__(self):
3645
return len(self.X)
@@ -41,7 +50,8 @@ def __getitem__(self, idx):
4150
else:
4251
input_points = self.X[idx]
4352
targets = self.y[idx]
44-
return dict(input_points=input_points, targets=targets)
53+
shifted_targets = self.shifted_y[idx]
54+
return dict(input_points=input_points, targets=targets, shifted_targets=shifted_targets)
4555

4656

4757
class TrainDataset(Dataset):
@@ -76,6 +86,12 @@ def build_train_set(self):
7686
self.X_std = self.X.std()
7787
self.y_mean = self.y.mean()
7888
self.y_std = self.y.std()
89+
# shift the fj_windSpeed
90+
shifted_list = []
91+
for i in range(1, self.args.shift_steps+1):
92+
shifted_windspeed = data_df["fj_windSpeed"].shift(i).fillna(0).to_numpy()[:, np.newaxis]
93+
shifted_list.append(shifted_windspeed)
94+
self.shifted_y = np.concatenate(shifted_list, axis=1)
7995

8096
def __len__(self):
8197
return self.args.epoch_steps * self.args.micro_bsz
@@ -86,7 +102,9 @@ def __getitem__(self, idx):
86102
if self.do_normalize:
87103
input_points = (self.X[s-self.prefix_len:s+self.seq_len] - self.X_mean) / self.X_std
88104
targets = (self.y[s:s+self.seq_len] - self.y_mean) / self.y_std
105+
shifted_targets = (self.shifted_y[s:s+self.seq_len] - self.y_mean) / self.y_std
89106
else:
90107
input_points = self.X[s-self.prefix_len:s+self.seq_len] # include the previous seq_len points
91108
targets = self.y[s:s+self.seq_len]
92-
return dict(input_points=input_points, targets=targets)
109+
shifted_targets = self.shifted_y[s:s+self.seq_len]
110+
return dict(input_points=input_points, targets=targets, shifted_targets=shifted_targets)

‎Pretrained-RWKV_TS/src/model.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,26 @@ def __init__(self, n_embd):
342342
self.drop = nn.Dropout(0.1)
343343

344344
def forward(self, x):
345-
return self.drop(torch.cat([self.conv3(x), self.conv5(x), self.conv7(x), self.conv9(x)], 1))
345+
B, T, C = x.size()
346+
return self.drop(torch.cat([self.conv3(x), self.conv5(x), self.conv7(x), self.conv9(x)], 1)).view(B, T, -1)
347+
348+
349+
class Projector2(nn.Module):
350+
def __init__(self, n_embd, shift_steps):
351+
super().__init__()
352+
self.conv3 = nn.Conv1d(1, n_embd//8, kernel_size=3, stride=1, groups=1, padding='same')
353+
self.conv5 = nn.Conv1d(1, n_embd//8, kernel_size=5, stride=1, groups=1, padding='same')
354+
self.conv7 = nn.Conv1d(1, n_embd//8, kernel_size=7, stride=1, groups=1, padding='same')
355+
self.conv9 = nn.Conv1d(1, n_embd//8, kernel_size=9, stride=1, groups=1, padding='same')
356+
self.linear = nn.Linear(shift_steps, n_embd//2)
357+
self.drop = nn.Dropout(0.1)
358+
359+
def forward(self, x, shifted_y):
360+
B, T, C = x.size()
361+
x = x.view(B, C, T)
362+
shifted_y_emb = self.linear(shifted_y)
363+
x_emb = torch.cat([self.conv3(x), self.conv5(x), self.conv7(x), self.conv9(x)], 1)
364+
return self.drop(torch.cat([x_emb.view(B, T, -1), shifted_y_emb], 2))
346365

347366

348367

@@ -353,7 +372,7 @@ def __init__(self, args):
353372
self.rwkv = RWKV(args)
354373
if args.load_model:
355374
self.load_rwkv_from_pretrained(args.load_model)
356-
self.proj = Projector(args.n_embd)
375+
self.proj = Projector2(args.n_embd, args.shift_steps)
357376
self.head = nn.Linear(args.n_embd, 1, bias=False)
358377
self.best_val_loss = torch.tensor(float("inf"))
359378
self.do_normalize = args.do_normalize
@@ -393,12 +412,11 @@ def configure_optimizers(self):
393412
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=True, amsgrad=False)
394413

395414
def forward(self, samples):
396-
x, targets = samples["input_points"], samples["targets"]
397-
B, T, C = x.shape
398-
x = self.proj(x.view(B, C, T)).view(B, T, -1)
415+
x, y, shifted_y = samples["input_points"], samples["targets"], samples["shifted_targets"]
416+
x = self.proj(x, shifted_y)
399417
x = self.rwkv(x)[:, self.prefix_len:, :] #
400418
outputs = self.head(x)
401-
return outputs, targets
419+
return outputs, y
402420

403421
def bidirectional_forward(self, x, x_emb=None):
404422
pass

‎Pretrained-RWKV_TS/train.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
parser.add_argument("--label_smoothing", default=0, type=int) # label smoothing window
5858
parser.add_argument("--prefix_len", default=0, type=int) #
5959
parser.add_argument("--do_normalize", action="store_true") # normalize input
60+
parser.add_argument("--shift_steps", default=5, type=int) # shift steps for fj_windSpeed
6061

6162
parser = Trainer.add_argparse_args(parser)
6263
args = parser.parse_args()

0 commit comments

Comments
 (0)
Please sign in to comment.