@@ -15,9 +15,7 @@ def test_additive_attention():
     enc_seq_len = torch.arange(start=10, end=20)  # [10, ..., 19]
 
     # pass key as weight feedback just for testing
-    context, weights = att(
-        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
-    )
+    context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len)
     assert context.shape == (10, 5)
     assert weights.shape == (10, 20, 1)
 
@@ -42,7 +40,6 @@ def test_encoder_decoder_attention_model():
         output_dropout=0.1,
         zoneout_drop_c=0.0,
         zoneout_drop_h=0.0,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
@@ -69,7 +66,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
         output_dropout=0.1,
         zoneout_drop_c=zoneout_drop_c,
         zoneout_drop_h=zoneout_drop_h,
-        device="cpu",
    )
    decoder = AttentionLSTMDecoderV1(decoder_cfg)
    decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)