@@ -342,7 +342,26 @@ def __init__(self, n_embd):
342
342
self .drop = nn .Dropout (0.1 )
343
343
344
344
def forward(self, x):
    """Fuse the four multi-width Conv1d branches and flatten back to
    (batch, seq, features).

    Concatenates conv3/5/7/9 outputs along the channel axis, applies
    dropout, then reshapes using the (batch, seq) sizes read from the
    input.

    NOTE(review): the trailing ``.view`` assumes the sizes read from the
    *input* still describe the conv output's layout — confirm the
    caller's tensor layout before relying on the feature ordering.
    """
    batch, seq_len, _ = x.size()
    branches = [conv(x) for conv in (self.conv3, self.conv5, self.conv7, self.conv9)]
    mixed = torch.cat(branches, dim=1)
    return self.drop(mixed).view(batch, seq_len, -1)
+
348
+
349
class Projector2(nn.Module):
    """Input projector that mixes multi-scale conv features of the raw
    series with a linear embedding of the shifted targets.

    The four parallel 1-channel convolutions each contribute
    ``n_embd // 8`` channels (``n_embd // 2`` in total); the linear layer
    maps the ``shift_steps`` shifted-target values to the remaining
    ``n_embd // 2`` dimensions, so the concatenated output has ``n_embd``
    features per position.
    """

    def __init__(self, n_embd, shift_steps):
        super().__init__()
        # Parallel single-input-channel convolutions at four receptive-field widths.
        self.conv3 = nn.Conv1d(1, n_embd // 8, kernel_size=3, stride=1, groups=1, padding='same')
        self.conv5 = nn.Conv1d(1, n_embd // 8, kernel_size=5, stride=1, groups=1, padding='same')
        self.conv7 = nn.Conv1d(1, n_embd // 8, kernel_size=7, stride=1, groups=1, padding='same')
        self.conv9 = nn.Conv1d(1, n_embd // 8, kernel_size=9, stride=1, groups=1, padding='same')
        # Embeds the shift_steps shifted-target values per position.
        self.linear = nn.Linear(shift_steps, n_embd // 2)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, shifted_y):
        batch, seq_len, n_ch = x.size()
        # NOTE(review): ``view`` — not ``transpose`` — moves the channel
        # axis here; that is only a true axis swap when n_ch == 1, which
        # the Conv1d(in_channels=1) layers already require. Same caveat
        # for the ``conv_feats.view`` below (only a pure axis swap when
        # one of the reshaped dims is 1) — confirm intended layout.
        series = x.view(batch, n_ch, seq_len)
        conv_feats = torch.cat(
            [self.conv3(series), self.conv5(series), self.conv7(series), self.conv9(series)],
            dim=1,
        )
        shift_emb = self.linear(shifted_y)
        combined = torch.cat([conv_feats.view(batch, seq_len, -1), shift_emb], dim=2)
        return self.drop(combined)
346
365
347
366
348
367
@@ -353,7 +372,7 @@ def __init__(self, args):
353
372
self .rwkv = RWKV (args )
354
373
if args .load_model :
355
374
self .load_rwkv_from_pretrained (args .load_model )
356
- self .proj = Projector (args .n_embd )
375
+ self .proj = Projector2 (args .n_embd , args . shift_steps )
357
376
self .head = nn .Linear (args .n_embd , 1 , bias = False )
358
377
self .best_val_loss = torch .tensor (float ("inf" ))
359
378
self .do_normalize = args .do_normalize
@@ -393,12 +412,11 @@ def configure_optimizers(self):
393
412
return FusedAdam (optim_groups , lr = self .args .lr_init , betas = self .args .betas , eps = self .args .adam_eps , bias_correction = True , adam_w_mode = True , amsgrad = False )
394
413
395
414
def forward(self, samples):
    """Run one forward pass over a batch dict.

    Expects ``samples`` to carry "input_points", "targets", and
    "shifted_targets". Projects the input points together with the
    shifted targets, runs the RWKV backbone, drops the first
    ``self.prefix_len`` positions, and applies the prediction head.

    Returns the head outputs and the (untouched) targets tensor.
    """
    points = samples["input_points"]
    targets = samples["targets"]
    shifted = samples["shifted_targets"]
    hidden = self.rwkv(self.proj(points, shifted))
    # Discard the prefix positions before scoring.
    hidden = hidden[:, self.prefix_len:, :]
    return self.head(hidden), targets
402
420
403
421
def bidirectional_forward(self, x, x_emb=None):
    """Placeholder for a bidirectional pass; intentionally a no-op (returns None)."""
    pass
0 commit comments