From 8635f0e284186d440736899308e1b2f5475a54d5 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Fri, 1 Oct 2021 12:58:11 +0000
Subject: [PATCH 01/13] Add alignment CRF test. Fix missing fill_()

---
 tests/test_alignment_crf.py | 22 ++++++++++++++++++++++
 torch_struct/alignment.py   |  9 +++++----
 2 files changed, 27 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_alignment_crf.py

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
new file mode 100644
index 0000000..9b5d199
--- /dev/null
+++ b/tests/test_alignment_crf.py
@@ -0,0 +1,22 @@
+import torch
+import torch_struct
+import warnings
+
+
+def test_alignment_crf():
+    batch, N, M = 1, 4, 5
+    log_potentials = torch.rand(batch, N, M, 3).cuda()
+
+    try:
+        log_potentials = log_potentials.cuda()
+        on_cuda = True
+
+    except Exception:
+        warnings.warn('Could not move log potentials to CUDA device. '
+                      'Will not test marginals.')
+        on_cuda = False
+
+    dist = torch_struct.AlignmentCRF(log_potentials)
+    assert (N, M, 3) == dist.argmax[0].shape
+    if on_cuda:
+        assert (N, M, 3) == dist.marginals[0].shape
diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py
index c9840c9..797b8da 100644
--- a/torch_struct/alignment.py
+++ b/torch_struct/alignment.py
@@ -1,11 +1,14 @@
 import torch
 from .helpers import _Struct
 import math
+import warnings

 try:
     import genbmm
+
 except ImportError:
-    pass
+    warnings.warn('Could not import genbmm. '
+                  'However, genbmm is only used for CUDA operations.')

 from .semirings import LogSemiring
 from .semirings.fast_semirings import broadcast
@@ -97,9 +100,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):

         # Create finalizing paths.
         point = (l + M) // 2
-        charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
-            charta[1][:, b, point:, 1, ind, :, :, Mid]
-        )
+        charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)

     for b in range(lengths.shape[0]):
         point = (lengths[b] + M) // 2

From e9a2bc31355759a8dc27f255a9461c940a3df8f7 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Fri, 1 Oct 2021 13:08:48 +0000
Subject: [PATCH 02/13] Fix exception type.

---
 tests/test_alignment_crf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
index 9b5d199..7699221 100644
--- a/tests/test_alignment_crf.py
+++ b/tests/test_alignment_crf.py
@@ -11,7 +11,7 @@ def test_alignment_crf():
         log_potentials = log_potentials.cuda()
         on_cuda = True

-    except Exception:
+    except RuntimeError:
         warnings.warn('Could not move log potentials to CUDA device. '
                       'Will not test marginals.')
         on_cuda = False

From acbce533e174c19b41898d4ee80f46a1bfb95f10 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Fri, 1 Oct 2021 13:15:05 +0000
Subject: [PATCH 03/13] Change CUDA detection

---
 tests/test_alignment_crf.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
index 7699221..aeb2356 100644
--- a/tests/test_alignment_crf.py
+++ b/tests/test_alignment_crf.py
@@ -7,16 +7,13 @@ def test_alignment_crf():
     batch, N, M = 1, 4, 5
     log_potentials = torch.rand(batch, N, M, 3).cuda()

-    try:
+    if torch.cuda.is_available():
         log_potentials = log_potentials.cuda()
-        on_cuda = True
-
-    except RuntimeError:
+    else:
         warnings.warn('Could not move log potentials to CUDA device. '
                       'Will not test marginals.')
-        on_cuda = False

     dist = torch_struct.AlignmentCRF(log_potentials)
     assert (N, M, 3) == dist.argmax[0].shape
-    if on_cuda:
+    if torch.cuda.is_available():
         assert (N, M, 3) == dist.marginals[0].shape

From e65832e6798f2301680f8ba71c3da91f4d2dfb72 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Fri, 1 Oct 2021 13:18:42 +0000
Subject: [PATCH 04/13] Remove unwanted `.cuda()`

---
 tests/test_alignment_crf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
index aeb2356..a9a6701 100644
--- a/tests/test_alignment_crf.py
+++ b/tests/test_alignment_crf.py
@@ -5,7 +5,7 @@

 def test_alignment_crf():
     batch, N, M = 1, 4, 5
-    log_potentials = torch.rand(batch, N, M, 3).cuda()
+    log_potentials = torch.rand(batch, N, M, 3)

     if torch.cuda.is_available():
         log_potentials = log_potentials.cuda()

From 9788879b604739182a5fde501803f2e581247cd6 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Fri, 1 Oct 2021 15:53:20 +0000
Subject: [PATCH 05/13] Add another property to test.

---
 tests/test_alignment_crf.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
index a9a6701..7865fb6 100644
--- a/tests/test_alignment_crf.py
+++ b/tests/test_alignment_crf.py
@@ -3,8 +3,8 @@
 import warnings


-def test_alignment_crf():
-    batch, N, M = 1, 4, 5
+def test_alignment_crf_shapes():
+    batch, N, M = 2, 4, 5
     log_potentials = torch.rand(batch, N, M, 3)

     if torch.cuda.is_available():
@@ -14,6 +14,17 @@ def test_alignment_crf():
                       'Will not test marginals.')

     dist = torch_struct.AlignmentCRF(log_potentials)
-    assert (N, M, 3) == dist.argmax[0].shape
+    assert (batch, N, M, 3) == dist.argmax.shape
     if torch.cuda.is_available():
-        assert (N, M, 3) == dist.marginals[0].shape
+        assert (batch, N, M, 3) == dist.marginals.shape
+        assert (batch,) == dist.partition.shape
+
+        # Fail due to AttributeError: 'BandedMatrix' object has no attribute
+        # 'unsqueeze'
+        # assert (batch,) == dist.entropy.shape
+        # assert (9, batch, N, M, 3) == dist.sample([9]).shape
+
+        # Fails due to: RuntimeError: Expected condition, x and y to be on
+        # the same device, but condition is on cpu and x and y are on
+        # cuda:0 and cuda:0 respectively
+        # assert (8, batch,) == dist.topk(8).shape

From 7d140f639f270102de8a88e4edd32132b9911698 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Wed, 13 Oct 2021 07:54:51 +0000
Subject: [PATCH 06/13] Update test to use skipif

---
 tests/test_alignment_crf.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py
index 7865fb6..c2bc24f 100644
--- a/tests/test_alignment_crf.py
+++ b/tests/test_alignment_crf.py
@@ -1,30 +1,25 @@
 import torch
 import torch_struct
 import warnings
+import pytest


+@pytest.skipif(not torch.cuda.isavailable(), 'needs CUDA')
 def test_alignment_crf_shapes():
     batch, N, M = 2, 4, 5
-    log_potentials = torch.rand(batch, N, M, 3)
-
-    if torch.cuda.is_available():
-        log_potentials = log_potentials.cuda()
-    else:
-        warnings.warn('Could not move log potentials to CUDA device. '
-                      'Will not test marginals.')
+    log_potentials = torch.rand(batch, N, M, 3).cuda()

     dist = torch_struct.AlignmentCRF(log_potentials)
     assert (batch, N, M, 3) == dist.argmax.shape
-    if torch.cuda.is_available():
-        assert (batch, N, M, 3) == dist.marginals.shape
-        assert (batch,) == dist.partition.shape
+    assert (batch, N, M, 3) == dist.marginals.shape
+    assert (batch,) == dist.partition.shape

-        # Fail due to AttributeError: 'BandedMatrix' object has no attribute
-        # 'unsqueeze'
-        # assert (batch,) == dist.entropy.shape
-        # assert (9, batch, N, M, 3) == dist.sample([9]).shape
+    # Fail due to AttributeError: 'BandedMatrix' object has no attribute
+    # 'unsqueeze'
+    assert (batch,) == dist.entropy.shape
+    assert (9, batch, N, M, 3) == dist.sample([9]).shape

-        # Fails due to: RuntimeError: Expected condition, x and y to be on
-        # the same device, but condition is on cpu and x and y are on
-        # cuda:0 and cuda:0 respectively
-        # assert (8, batch,) == dist.topk(8).shape
+    # Fails due to: RuntimeError: Expected condition, x and y to be on
+    # the same device, but condition is on cpu and x and y are on
+    # cuda:0 and cuda:0 respectively
+    assert (8, batch,) == dist.topk(8).shape

From 92cd4ebe22c461589026619d86da236de2b32cdb Mon Sep 17 00:00:00 2001
From: John Reid
Date: Wed, 13 Oct 2021 07:55:16 +0000
Subject: [PATCH 07/13] Use code from @srush in #109

---
 torch_struct/alignment.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py
index 797b8da..a2a4bb0 100644
--- a/torch_struct/alignment.py
+++ b/torch_struct/alignment.py
@@ -100,7 +100,9 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):

         # Create finalizing paths.
         point = (l + M) // 2
-        charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)
+        init = torch.zeros(charta[1].shape, device=charta[1].device).bool()
+        init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
+        charta[1] = semiring.fill(charta[1], init, semiring.one)

     for b in range(lengths.shape[0]):
         point = (lengths[b] + M) // 2

From 4fa50a8d75a53945af7a6547da4f1543e7dc1011 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Wed, 13 Oct 2021 07:56:30 +0000
Subject: [PATCH 08/13] Clamp infinite logits to fix sampling

---
 torch_struct/semirings/sample.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torch_struct/semirings/sample.py b/torch_struct/semirings/sample.py
index 09ec189..f228a0c 100644
--- a/torch_struct/semirings/sample.py
+++ b/torch_struct/semirings/sample.py
@@ -167,6 +167,9 @@ def forward(ctx, input, dim):

     def backward(ctx, grad_output):
         logits, part, dim = ctx.saved_tensors
+        # Replace infinite logits with max float, otherwise softmax gives NaNs
+        # Perhaps this could be done earlier (during forward pass)?
+        logits[logits == float('inf')] = torch.finfo(logits.dtype).max
         grad_input = None
         if ctx.needs_input_grad[0]:

From 995b745679aa1003294b9e798c3038fbef868b67 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Wed, 13 Oct 2021 07:57:09 +0000
Subject: [PATCH 09/13] Remove deadline to silence spurious timing warnings

---
 tests/test_semirings.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_semirings.py b/tests/test_semirings.py
index ab83a1c..d82ed28 100644
--- a/tests/test_semirings.py
+++ b/tests/test_semirings.py
@@ -1,5 +1,5 @@
 import torch
-from hypothesis import given
+from hypothesis import given, settings
 from hypothesis.strategies import integers


@@ -17,6 +17,7 @@


 @given(lint, lint, lint)
+@settings(deadline=None)  # Avoid spurious warnings when first run
 def test_max(a, b, c):
     torch.manual_seed(0)
     t1 = torch.rand(a, 1, c).requires_grad_(True)

From 5644e7e70cc20ca3f61b1dfa3a5f08b0268772f8 Mon Sep 17 00:00:00 2001
From: John Reid
Date: Wed, 13 Oct 2021 07:58:48 +0000
Subject: [PATCH 10/13] Overload matmul in more semirings and ensure tensors
 are on the correct device

---
 torch_struct/semirings/semirings.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py
index cfc2311..be2ccac 100644
--- a/torch_struct/semirings/semirings.py
+++ b/torch_struct/semirings/semirings.py
@@ -98,6 +98,13 @@ class _BaseLog(Semiring):
     zero = torch.tensor(-1e5)
     one = torch.tensor(-0.0)

+    @classmethod
+    def matmul(cls, a, b):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
+            return b.multiply_log(a.transpose())
+        else:
+            return Semiring.matmul(a, b)
+
     @staticmethod
     def sum(xs, dim=-1):
         return torch.logsumexp(xs, dim=dim)
@@ -140,13 +147,7 @@ class LogSemiring(_BaseLog):

     Gradients give marginals.
""" - - @classmethod - def matmul(cls, a, b): - if has_genbmm and isinstance(a, genbmm.BandedMatrix): - return b.multiply_log(a.transpose()) - else: - return _BaseLog.matmul(a, b) + pass class MaxSemiring(_BaseLog): @@ -192,7 +193,7 @@ def convert(cls, orig_potentials): dtype=orig_potentials.dtype, device=orig_potentials.device, ) - potentials = cls.fill(potentials, torch.tensor(True), cls.zero) + potentials = cls.fill(potentials, torch.tensor(True, device=potentials.device), cls.zero.to(potentials.device)) potentials[0] = orig_potentials return potentials @@ -393,6 +394,13 @@ def sum(xs, dim=-1): sm = log_sm.exp() return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d))) + @classmethod + def matmul(cls, a, b): + if has_genbmm and isinstance(a, genbmm.BandedMatrix): + return b.multiply(a.transpose()) + else: + return Semiring.matmul(a, b) + @staticmethod def mul(a, b): return torch.stack((a[0] + b[0], a[1] + b[1])) From 338510686e96879895c5341505c1d590ec42d261 Mon Sep 17 00:00:00 2001 From: John Reid Date: Wed, 13 Oct 2021 10:26:58 +0000 Subject: [PATCH 11/13] Fix skipif decoration --- tests/test_alignment_crf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py index c2bc24f..51943fb 100644 --- a/tests/test_alignment_crf.py +++ b/tests/test_alignment_crf.py @@ -1,10 +1,9 @@ import torch import torch_struct -import warnings import pytest -@pytest.skipif(not torch.cuda.isavailable(), 'needs CUDA') +@pytest.mark.skipif(not torch.cuda.is_available(), reason='needs CUDA') def test_alignment_crf_shapes(): batch, N, M = 2, 4, 5 log_potentials = torch.rand(batch, N, M, 3).cuda() From 7e41787741242d546f1e2b42864df7c2f6be2324 Mon Sep 17 00:00:00 2001 From: John Reid Date: Wed, 13 Oct 2021 11:04:43 +0000 Subject: [PATCH 12/13] Back out test breaking semiring matmul changes --- torch_struct/semirings/semirings.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index be2ccac..6355f1e 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -98,13 +98,6 @@ class _BaseLog(Semiring): zero = torch.tensor(-1e5) one = torch.tensor(-0.0) - @classmethod - def matmul(cls, a, b): - if has_genbmm and isinstance(a, genbmm.BandedMatrix): - return b.multiply_log(a.transpose()) - else: - return Semiring.matmul(a, b) - @staticmethod def sum(xs, dim=-1): return torch.logsumexp(xs, dim=dim) @@ -147,7 +140,12 @@ class LogSemiring(_BaseLog): Gradients give marginals. 
""" - pass + @classmethod + def matmul(cls, a, b): + if has_genbmm and isinstance(a, genbmm.BandedMatrix): + return b.multiply_log(a.transpose()) + else: + return _BaseLog.matmul(a, b) class MaxSemiring(_BaseLog): @@ -394,13 +392,6 @@ def sum(xs, dim=-1): sm = log_sm.exp() return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d))) - @classmethod - def matmul(cls, a, b): - if has_genbmm and isinstance(a, genbmm.BandedMatrix): - return b.multiply(a.transpose()) - else: - return Semiring.matmul(a, b) - @staticmethod def mul(a, b): return torch.stack((a[0] + b[0], a[1] + b[1])) From 3fb46abb6d18374d0e1434abe0029b632f07eeb3 Mon Sep 17 00:00:00 2001 From: John Reid Date: Wed, 20 Apr 2022 08:21:08 +0000 Subject: [PATCH 13/13] WIP: add tests --- tests/test_alignment_crf.py | 4 ++-- tests/test_matmul.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 tests/test_matmul.py diff --git a/tests/test_alignment_crf.py b/tests/test_alignment_crf.py index 51943fb..851796f 100644 --- a/tests/test_alignment_crf.py +++ b/tests/test_alignment_crf.py @@ -16,9 +16,9 @@ def test_alignment_crf_shapes(): # Fail due to AttributeError: 'BandedMatrix' object has no attribute # 'unsqueeze' assert (batch,) == dist.entropy.shape - assert (9, batch, N, M, 3) == dist.sample([9]).shape + # assert (9, batch, N, M, 3) == dist.sample([9]).shape # Fails due to: RuntimeError: Expected condition, x and y to be on # the same device, but condition is on cpu and x and y are on # cuda:0 and cuda:0 respectively - assert (8, batch,) == dist.topk(8).shape + # assert (8, batch,) == dist.topk(8).shape diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 0000000..968d7a6 --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,14 @@ +import torch +from hypothesis import given +from hypothesis.strategies import integers +import genbmm + +bint = integers(min_value=1, max_value=4) +mint = integers(min_value=6, max_value=8) +nint = integers(min_value=3, max_value=5) +kint = integers(min_value=9, max_value=11) + + +@given(bint, mint, nint, kint) +def test_matmul(batch, m, n, k): + a, b = torch.rand((m, n)), torch.rand((n, k))