Skip to content

Commit 6a6e044

Browse files
committed
remove device from attention test
1 parent 315b579 commit 6a6e044

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

tests/test_enc_dec_att.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def test_additive_attention():
1515
enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19]
1616

1717
# pass key as weight feedback just for testing
18-
context, weights = att(
19-
key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
20-
)
18+
context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len)
2119
assert context.shape == (10, 5)
2220
assert weights.shape == (10, 20, 1)
2321

@@ -42,7 +40,6 @@ def test_encoder_decoder_attention_model():
4240
output_dropout=0.1,
4341
zoneout_drop_c=0.0,
4442
zoneout_drop_h=0.0,
45-
device="cpu",
4643
)
4744
decoder = AttentionLSTMDecoderV1(decoder_cfg)
4845
target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N]
@@ -69,7 +66,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
6966
output_dropout=0.1,
7067
zoneout_drop_c=zoneout_drop_c,
7168
zoneout_drop_h=zoneout_drop_h,
72-
device="cpu",
7369
)
7470
decoder = AttentionLSTMDecoderV1(decoder_cfg)
7571
decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)

0 commit comments

Comments
 (0)