diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
new file mode 100644
index 0000000..851796f
--- /dev/null
+++ b/tests/test_alignment_crf.py
@@ -0,0 +1,24 @@
+import torch
+import torch_struct
+import pytest
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='needs CUDA')
+def test_alignment_crf_shapes():
+    batch, N, M = 2, 4, 5
+    log_potentials = torch.rand(batch, N, M, 3).cuda()
+
+    dist = torch_struct.AlignmentCRF(log_potentials)
+    assert (batch, N, M, 3) == dist.argmax.shape
+    assert (batch, N, M, 3) == dist.marginals.shape
+    assert (batch,) == dist.partition.shape
+
+    # Fails due to AttributeError: 'BandedMatrix' object has no attribute
+    # 'unsqueeze'
+    assert (batch,) == dist.entropy.shape
+    # assert (9, batch, N, M, 3) == dist.sample([9]).shape
+
+    # Fails due to: RuntimeError: Expected condition, x and y to be on
+    # the same device, but condition is on cpu and x and y are on
+    # cuda:0 and cuda:0 respectively
+    # assert (8, batch,) == dist.topk(8).shape
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
new file mode 100644
index 0000000..968d7a6
--- /dev/null
+++ b/tests/test_matmul.py
@@ -0,0 +1,14 @@
+import torch
+from hypothesis import given
+from hypothesis.strategies import integers
+import genbmm
+
+bint = integers(min_value=1, max_value=4)
+mint = integers(min_value=6, max_value=8)
+nint = integers(min_value=3, max_value=5)
+kint = integers(min_value=9, max_value=11)
+
+
+@given(bint, mint, nint, kint)
+def test_matmul(batch, m, n, k):
+    a, b = torch.rand((m, n)), torch.rand((n, k))
diff --git a/tests/test_semirings.py b/tests/test_semirings.py
index ab83a1c..d82ed28 100644
--- a/tests/test_semirings.py
+++ b/tests/test_semirings.py
@@ -1,5 +1,5 @@
 import torch
-from hypothesis import given
+from hypothesis import given, settings
 from hypothesis.strategies import integers
 
 
@@ -17,6 +17,7 @@
 
 
 @given(lint, lint, lint)
+@settings(deadline=None)  # Avoid spurious warnings on the first run
 def test_max(a, b, c):
     torch.manual_seed(0)
     t1 = torch.rand(a, 1, c).requires_grad_(True)
diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py
index c9840c9..a2a4bb0 100644
--- a/torch_struct/alignment.py
+++ b/torch_struct/alignment.py
@@ -1,11 +1,14 @@
 import torch
 from .helpers import _Struct
 import math
+import warnings
 
 try:
     import genbmm
+
 except ImportError:
-    pass
+    warnings.warn('Could not import genbmm. '
+                  'However, genbmm is only used for CUDA operations.')
 
 from .semirings import LogSemiring
 from .semirings.fast_semirings import broadcast
@@ -97,9 +100,9 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
 
             # Create finalizing paths.
             point = (l + M) // 2
-            charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
-                charta[1][:, b, point:, 1, ind, :, :, Mid]
-            )
+            init = torch.zeros(charta[1].shape, device=charta[1].device).bool()
+            init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
+            charta[1] = semiring.fill(charta[1], init, semiring.one)
 
         for b in range(lengths.shape[0]):
             point = (lengths[b] + M) // 2
diff --git a/torch_struct/semirings/sample.py b/torch_struct/semirings/sample.py
index 09ec189..f228a0c 100644
--- a/torch_struct/semirings/sample.py
+++ b/torch_struct/semirings/sample.py
@@ -167,6 +167,9 @@ def forward(ctx, input, dim):
 
 
     def backward(ctx, grad_output):
        logits, part, dim = ctx.saved_tensors
+        # Replace infinite logits with max float, otherwise softmax gives NaNs
+        # Perhaps this could be done earlier (during the forward pass)?
+        logits[logits == float('inf')] = torch.finfo(logits.dtype).max
         grad_input = None
         if ctx.needs_input_grad[0]:
diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py
index cfc2311..6355f1e 100644
--- a/torch_struct/semirings/semirings.py
+++ b/torch_struct/semirings/semirings.py
@@ -140,7 +140,6 @@ class LogSemiring(_BaseLog):
 
     Gradients give marginals.
     """
-
     @classmethod
     def matmul(cls, a, b):
         if has_genbmm and isinstance(a, genbmm.BandedMatrix):
@@ -192,7 +191,7 @@ def convert(cls, orig_potentials):
             dtype=orig_potentials.dtype,
             device=orig_potentials.device,
         )
-        potentials = cls.fill(potentials, torch.tensor(True), cls.zero)
+        potentials = cls.fill(potentials, torch.tensor(True, device=potentials.device), cls.zero.to(potentials.device))
         potentials[0] = orig_potentials
         return potentials
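
Note on tests/test_matmul.py: the new test draws shapes via hypothesis but does not yet call genbmm or assert anything. A minimal sketch of how the body might be completed, checking genbmm's log-space batched matmul against a plain PyTorch reference; the `logbmm` call and the 3-D input shapes follow the genbmm README, and a CUDA device is assumed since genbmm ships CUDA kernels:

    import torch
    import genbmm

    def test_matmul(batch, m, n, k):
        # Batched operands on the GPU; genbmm works on 3-D CUDA tensors.
        a = torch.rand(batch, m, n).cuda()
        b = torch.rand(batch, n, k).cuda()
        # genbmm's log-space batched matrix product.
        c = genbmm.logbmm(a, b)
        # Reference: broadcast-add, then logsumexp over the shared dimension n.
        c2 = (a.unsqueeze(-1) + b.unsqueeze(-3)).logsumexp(-2)
        assert torch.allclose(c, c2, atol=1e-5)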
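
The clamp added in torch_struct/semirings/sample.py guards the softmax taken in backward: an infinite logit makes softmax return NaN, while the largest finite float keeps it well defined. The failure mode is easy to reproduce in isolation:

    import torch

    logits = torch.tensor([0.0, float('inf')])
    torch.softmax(logits, 0)   # tensor([0., nan]): exp(inf - inf) is NaN
    # The fix from the diff: replace infinite logits with the max finite value.
    logits[logits == float('inf')] = torch.finfo(logits.dtype).max
    torch.softmax(logits, 0)   # tensor([0., 1.]): well defined after clamping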
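
The convert() change in torch_struct/semirings/semirings.py fixes the same class of error quoted next to the skipped topk assertion: torch.where, which the semiring fill is built on, insists that its condition tensor live on the same device as its operands, and torch.tensor(True) lands on the CPU by default. A minimal reproduction, assuming a CUDA device:

    import torch

    x, y = torch.rand(3).cuda(), torch.rand(3).cuda()
    cond = torch.tensor(True)    # created on the CPU by default
    # torch.where(cond, x, y)    # RuntimeError: Expected condition, x and y
    #                            # to be on the same device ...
    cond = torch.tensor(True, device=x.device)
    torch.where(cond, x, y)      # fine once the mask is on the operands' device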