diff --git a/test/scan/test_scan.py b/test/scan/test_scan.py
index b61d8648fa2d..6023f302e8e9 100644
--- a/test/scan/test_scan.py
+++ b/test/scan/test_scan.py
@@ -75,12 +75,16 @@ def run_test(self,
             if v is not None else 0, tree_leaves(t)), torch.tensor(0.0))
     dupe = lambda v: v.detach().clone().requires_grad_(v.requires_grad)
 
+    def _requires_grad(tensors):
+      return any(tree_flatten(tree_map(lambda v: v.requires_grad, tensors))[0])
+
     # Actual output
     init_scan = tree_map(dupe, init)
     xs_scan = tree_map(dupe, xs)
     final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn)
     # Add up all leaves and `backward()` once.
-    (squish(final_carry) + squish(ys)).backward()
+    if _requires_grad(final_carry) or _requires_grad(ys):
+      (squish(final_carry) + squish(ys)).backward()
     torch_xla.sync()
 
     # Expected output
@@ -88,7 +92,8 @@ def run_test(self,
     xs_loop = tree_map(dupe, xs)
     expected_final_carry, expected_ys = _loopy_scan(fn, init_loop, xs_loop)
     # Add up all leaves and `backward()` once.
-    (squish(expected_final_carry) + squish(expected_ys)).backward()
+    if _requires_grad(expected_final_carry) or _requires_grad(expected_ys):
+      (squish(expected_final_carry) + squish(expected_ys)).backward()
     torch_xla.sync()
 
     # Compare values
@@ -126,6 +131,31 @@ def step_fn(carry, x):
     self.compare_pytree(expected_final_carry, final_carry)
     self.compare_pytree(expected_ys, ys)
 
+  def test_scan_long_tensor(self):
+    """This test uses `scan` to implement `torch.cumsum`."""
+
+    def step_fn(carry, x):
+      new_carry = carry + x
+      y = new_carry
+      return new_carry, y
+
+    init = torch.tensor([0.0, 0.0],
+                        requires_grad=False,
+                        device=self.device,
+                        dtype=torch.long)
+    xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                      requires_grad=False,
+                      dtype=torch.long,
+                      device=self.device)
+    final_carry, ys = self.run_test(step_fn, init, xs)
+
+    # Also ensure that our loop-based scan is correct, with manual checks
+    # that replicate the step_fn.
+    expected_final_carry = torch.sum(xs, dim=0) + init
+    expected_ys = torch.cumsum(xs, dim=0)
+    self.compare_pytree(expected_final_carry, final_carry)
+    self.compare_pytree(expected_ys, ys)
+
   def test_scan_fn_not_callable(self):
     init = torch.tensor([1.0, 1.0], device=self.device)
     xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], device=self.device)
diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py
index c517e252c163..df4d6c3fa543 100644
--- a/torch_xla/experimental/scan.py
+++ b/torch_xla/experimental/scan.py
@@ -229,7 +229,8 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
     return torch.empty_like(
         v, dtype=v.dtype, device=v.device, requires_grad=requires_grad)
 
-  fake_carry_pytree = tree_map(make_fake_tensor, init)
+  fake_carry_pytree = tree_map(
+      lambda v: make_fake_tensor(v, requires_grad=v.is_floating_point()), init)
   fake_x_pytree = tree_map(
       lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs)
 
@@ -261,11 +262,15 @@ def fn_no_output_aliasing(*args):
   # intermediate activations.
   num_out = len(list(tree_iter(out)))
   # Capture the backward.
-  out, unflatten_fwd_out = tree_flatten_none(out)
-  torch.autograd.backward(out, tree_map(lambda v: torch.ones_like(v), out))
+  flat_out, unflatten_fwd_out = tree_flatten_none(out)
+  out_with_grad = list(filter(lambda v: v.requires_grad, flat_out))
+  if len(out_with_grad) > 0:
+    torch.autograd.backward(
+        out_with_grad, tree_map(lambda v: torch.ones_like(v), out_with_grad))
 
   fwd_graph = get_fwd()
-  bwd_graph = get_bwd()
+  if len(out_with_grad) > 0:
+    bwd_graph = get_bwd()
 
   # Figure out which activations are aliases to the inputs. We don't need to
   # pass them through the scan logic unchanged. That would use more memory.
@@ -320,6 +325,8 @@ def alias_input(partial_activations, xs):
     return tuple(activations)
 
   def backward(carry, x):
+    if len(out_with_grad) == 0:
+      return None, None
     grad_new_carry, _ = tree_flatten(carry)
     (grad_y, activations) = x
     grad_y, _ = tree_flatten_none(grad_y)
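A minimal usage sketch of the behavior this patch enables: because non-floating-point carries are now traced with requires_grad=False, `scan` can run over `torch.long` tensors without capturing a backward graph, mirroring the new `test_scan_long_tensor` test. The `scan` call below matches the API exercised in the test; the surrounding setup (variable names, device acquisition via `xm.xla_device()`) is illustrative only.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

def step_fn(carry, x):
  # Running sum: carry the cumulative total forward and also emit it as y.
  new_carry = carry + x
  return new_carry, new_carry

device = xm.xla_device()
init = torch.zeros(2, dtype=torch.long, device=device)
xs = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.long, device=device)
final_carry, ys = scan(step_fn, init, xs)  # ys matches torch.cumsum(xs, dim=0)
torch_xla.sync()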