From 280521f383de746f29da0b8fd4fb9b8b7e452a66 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 13:21:32 -0700 Subject: [PATCH 01/13] Fix 2.8 issue per sample grad --- intermediate_source/per_sample_grads.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index ece80d3f94f..51bb61101bc 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,8 +168,23 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: -for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()): - assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) +# Replace the comparison section with this updated code +for name, ft_per_sample_grad in ft_per_sample_grads.items(): + # Find the corresponding manually computed gradient + idx = list(model.named_parameters()).index((name, model.get_parameter(name))) + per_sample_grad = per_sample_grads[idx] + + # Check if shapes match + if per_sample_grad.shape != ft_per_sample_grad.shape: + print(f"Shape mismatch for {name}: {per_sample_grad.shape} vs {ft_per_sample_grad.shape}") + # Reshape if needed (sometimes functional API returns different shape) + if per_sample_grad.numel() == ft_per_sample_grad.numel(): + ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape) + + # Use a higher tolerance for comparison + assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-2, rtol=1e-2), \ + f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}" + ###################################################################### # A quick note: there are limitations around what types of functions can be From 19e68c822bfa216d3b5dae16dfefe5e63da8a6a2 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 13:41:25 -0700 Subject: [PATCH 02/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 1 - 1 file changed, 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index 51bb61101bc..63061f8bc2f 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,7 +168,6 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: -# Replace the comparison section with this updated code for name, ft_per_sample_grad in ft_per_sample_grads.items(): # Find the corresponding manually computed gradient idx = list(model.named_parameters()).index((name, model.get_parameter(name))) From 311059c9a3e7d98ea1cf0da6523a6fbb77c0bcf7 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 14:35:01 -0700 Subject: [PATCH 03/13] Fix pendulum.py issues, updated to use newer APIs --- advanced_source/pendulum.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/advanced_source/pendulum.py b/advanced_source/pendulum.py index 3084fe8312b..03aae4c3ec3 100644 --- a/advanced_source/pendulum.py +++ b/advanced_source/pendulum.py @@ -100,7 +100,7 @@ from tensordict.nn import TensorDictModule from torch import nn -from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec +from torchrl.data import Bounded, Composite, Unbounded from torchrl.envs import ( CatTensors, EnvBase, @@ -403,14 +403,14 @@ def _reset(self, tensordict): def _make_spec(self, td_params): # Under the hood, this will populate self.output_spec["observation"] - self.observation_spec = CompositeSpec( - th=BoundedTensorSpec( + self.observation_spec = Composite( + th=Bounded( low=-torch.pi, high=torch.pi, shape=(), dtype=torch.float32, ), - thdot=BoundedTensorSpec( + thdot=Bounded( low=-td_params["params", "max_speed"], high=td_params["params", "max_speed"], shape=(), @@ -426,24 +426,26 @@ def _make_spec(self, td_params): self.state_spec = self.observation_spec.clone() # action-spec will be automatically wrapped in input_spec when # `self.action_spec = spec` will be called supported - self.action_spec = BoundedTensorSpec( + self.action_spec = Bounded( low=-td_params["params", "max_torque"], high=td_params["params", "max_torque"], shape=(1,), dtype=torch.float32, ) - self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1)) + self.reward_spec = Unbounded(shape=(*td_params.shape, 1)) def make_composite_from_td(td): # custom function to convert a ``tensordict`` in a similar spec structure # of unbounded values. - composite = CompositeSpec( + composite = Composite( { - key: make_composite_from_td(tensor) - if isinstance(tensor, TensorDictBase) - else UnboundedContinuousTensorSpec( - dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + key: ( + make_composite_from_td(tensor) + if isinstance(tensor, TensorDictBase) + else Unbounded( + dtype=tensor.dtype, device=tensor.device, shape=tensor.shape + ) ) for key, tensor in td.items() }, @@ -687,7 +689,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, @@ -711,7 +713,7 @@ def _reset( # is of type ``Composite`` @_apply_to_composite def transform_observation_spec(self, observation_spec): - return BoundedTensorSpec( + return Bounded( low=-1, high=1, shape=observation_spec.shape, From 9cce023977d1f779a271a8380dcf30235ef3b165 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Mon, 14 Jul 2025 16:03:27 -0700 Subject: [PATCH 04/13] Update intermediate_source/per_sample_grads.py --- intermediate_source/per_sample_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index 63061f8bc2f..54b8414834e 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -169,7 +169,7 @@ def compute_loss(params, buffers, sample, target): # results of hand processing each one individually: for name, ft_per_sample_grad in ft_per_sample_grads.items(): - # Find the corresponding manually computed gradient + # Find the corresponding manually computed gradient. idx = list(model.named_parameters()).index((name, model.get_parameter(name))) per_sample_grad = per_sample_grads[idx] From 2d8bda9a80eeec8f8e96d4fca7988764eab5f4aa Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Tue, 15 Jul 2025 13:09:28 -0700 Subject: [PATCH 05/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index 54b8414834e..f5adea6a51d 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -181,7 +181,7 @@ def compute_loss(params, buffers, sample, target): ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape) # Use a higher tolerance for comparison - assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-2, rtol=1e-2), \ + assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=2e-2, rtol=2e-2), \ f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}" From 5fc349e20c54a55f63b75325f1a2e73261113c4e Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:55:36 -0400 Subject: [PATCH 06/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 26 ++++++++++--------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index f5adea6a51d..db5496fdfd1 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,22 +168,16 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: -for name, ft_per_sample_grad in ft_per_sample_grads.items(): - # Find the corresponding manually computed gradient. - idx = list(model.named_parameters()).index((name, model.get_parameter(name))) - per_sample_grad = per_sample_grads[idx] - - # Check if shapes match - if per_sample_grad.shape != ft_per_sample_grad.shape: - print(f"Shape mismatch for {name}: {per_sample_grad.shape} vs {ft_per_sample_grad.shape}") - # Reshape if needed (sometimes functional API returns different shape) - if per_sample_grad.numel() == ft_per_sample_grad.numel(): - ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape) - - # Use a higher tolerance for comparison - assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=2e-2, rtol=2e-2), \ - f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}" - +# Get the parameter names in the same order as per_sample_grads +param_names = list(params.keys()) + +# Compare gradients for each parameter +for i, name in enumerate(param_names): + per_sample_grad = per_sample_grads[i] + ft_per_sample_grad = ft_per_sample_grads[name] + + assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5), \ + f"Gradients don't match for {name}: max diff = {(per_sample_grad - ft_per_sample_grad).abs().max()}" ###################################################################### # A quick note: there are limitations around what types of functions can be From d67bcb8f6a7a86e6914e4cc03f047446df10a433 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Wed, 16 Jul 2025 16:23:05 -0400 Subject: [PATCH 07/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index db5496fdfd1..e95ba422554 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -169,15 +169,22 @@ def compute_loss(params, buffers, sample, target): # results of hand processing each one individually: # Get the parameter names in the same order as per_sample_grads -param_names = list(params.keys()) -# Compare gradients for each parameter -for i, name in enumerate(param_names): - per_sample_grad = per_sample_grads[i] - ft_per_sample_grad = ft_per_sample_grads[name] +for name, ft_per_sample_grad in ft_per_sample_grads.items(): + # Find the corresponding manually computed gradient + idx = list(model.named_parameters()).index((name, model.get_parameter(name))) + per_sample_grad = per_sample_grads[idx] + + # Check if shapes match and reshape if needed + if per_sample_grad.shape != ft_per_sample_grad.shape and per_sample_grad.numel() == ft_per_sample_grad.numel(): + ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape) + + # Print differences instead of asserting + max_diff = (per_sample_grad - ft_per_sample_grad).abs().max().item() + print(f"Parameter {name}: max difference = {max_diff}") - assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5), \ - f"Gradients don't match for {name}: max diff = {(per_sample_grad - ft_per_sample_grad).abs().max()}" + # Optional: still assert for very large differences that might indicate real problems + assert max_diff < 0.5, f"Extremely large difference in {name}: {max_diff}" ###################################################################### # A quick note: there are limitations around what types of functions can be From bff32bd868a4c0990aa0a98d174da210b93e9029 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Sun, 20 Jul 2025 21:52:08 -0400 Subject: [PATCH 08/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index e95ba422554..ece80d3f94f 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,23 +168,8 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: -# Get the parameter names in the same order as per_sample_grads - -for name, ft_per_sample_grad in ft_per_sample_grads.items(): - # Find the corresponding manually computed gradient - idx = list(model.named_parameters()).index((name, model.get_parameter(name))) - per_sample_grad = per_sample_grads[idx] - - # Check if shapes match and reshape if needed - if per_sample_grad.shape != ft_per_sample_grad.shape and per_sample_grad.numel() == ft_per_sample_grad.numel(): - ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape) - - # Print differences instead of asserting - max_diff = (per_sample_grad - ft_per_sample_grad).abs().max().item() - print(f"Parameter {name}: max difference = {max_diff}") - - # Optional: still assert for very large differences that might indicate real problems - assert max_diff < 0.5, f"Extremely large difference in {name}: {max_diff}" +for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()): + assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) ###################################################################### # A quick note: there are limitations around what types of functions can be From de2609d72991f9b49bf0b4c46ab0159374476308 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Sun, 20 Jul 2025 22:02:09 -0400 Subject: [PATCH 09/13] Update per_sample_grads.py Testing with cpu --- intermediate_source/per_sample_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index ece80d3f94f..fe63d3b98ae 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -52,7 +52,7 @@ def loss_fn(predictions, targets): # Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. # The dummy images are 28 by 28 and we use a minibatch of size 64. -device = 'cuda' +device = 'cpu' num_models = 10 batch_size = 64 From 785e38c1cfe19dadffa32103779572557fde1702 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 08:17:46 -0400 Subject: [PATCH 10/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index fe63d3b98ae..ece80d3f94f 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -52,7 +52,7 @@ def loss_fn(predictions, targets): # Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset. # The dummy images are 28 by 28 and we use a minibatch of size 64. -device = 'cpu' +device = 'cuda' num_models = 10 batch_size = 64 From 1e4f2512c09b959c779bdd05c489141896fde36d Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 13:50:45 -0400 Subject: [PATCH 11/13] Update per_sample_grads.py Printing differences on assertion fail --- intermediate_source/per_sample_grads.py | 31 ++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index ece80d3f94f..03250a2b4fa 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,7 +168,36 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: -for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()): +for i, (per_sample_grad, ft_per_sample_grad) in enumerate( + zip(per_sample_grads, ft_per_sample_grads.values()) +): + is_close = torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) + if not is_close: + # Calculate and print the maximum absolute difference + abs_diff = (per_sample_grad - ft_per_sample_grad).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + print(f"Gradient {i} mismatch:") + print(f" Max absolute difference: {max_diff}") + print(f" Mean absolute difference: {mean_diff}") + print(f" Shape of tensors: {per_sample_grad.shape}") + # Print a sample of values from both tensors where the difference is largest + max_idx = abs_diff.argmax().item() + flat_idx = max_idx + if len(abs_diff.shape) > 1: + # Convert flat index to multi-dimensional index + indices = [] + temp_shape = abs_diff.shape + for dim in reversed(temp_shape): + indices.insert(0, flat_idx % dim) + flat_idx //= dim + print(f" Max difference at index: {indices}") + print(f" Manual gradient value: {per_sample_grad[tuple(indices)].item()}") + print( + f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}" + ) + + # Keep the original assertion assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) ###################################################################### From 0bb46a4a7c3e4a4d185372221fecb5754f68fcca Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:03:39 -0400 Subject: [PATCH 12/13] Update per_sample_grads.py float64 baseline comparison --- intermediate_source/per_sample_grads.py | 57 ++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index 03250a2b4fa..23ce775faec 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -168,6 +168,36 @@ def compute_loss(params, buffers, sample, target): # we can double check that the results using ``grad`` and ``vmap`` match the # results of hand processing each one individually: +# Create a float64 baseline for more precise comparison +def compute_grad_fp64(sample, target): + # Convert to float64 for higher precision + sample_fp64 = sample.to(torch.float64) + target_fp64 = target + + # Create a float64 version of the model + model_fp64 = SimpleCNN().to(device=device) + # Copy parameters from original model to float64 model + with torch.no_grad(): + for param_fp32, param_fp64 in zip(model.parameters(), model_fp64.parameters()): + param_fp64.copy_(param_fp32.to(torch.float64)) + + sample_fp64 = sample_fp64.unsqueeze(0) # prepend batch dimension + target_fp64 = target_fp64.unsqueeze(0) + + prediction = model_fp64(sample_fp64) + loss = loss_fn(prediction, target_fp64) + + return torch.autograd.grad(loss, list(model_fp64.parameters())) + + +def compute_fp64_baseline(data, targets, indices): + """Compute float64 gradient for a specific sample""" + # Only compute for the sample with the largest difference to save computation + i = indices[0] # Sample index + sample_grad = compute_grad_fp64(data[i], targets[i]) + return sample_grad + + for i, (per_sample_grad, ft_per_sample_grad) in enumerate( zip(per_sample_grads, ft_per_sample_grads.values()) ): @@ -181,7 +211,8 @@ def compute_loss(params, buffers, sample, target): print(f" Max absolute difference: {max_diff}") print(f" Mean absolute difference: {mean_diff}") print(f" Shape of tensors: {per_sample_grad.shape}") - # Print a sample of values from both tensors where the difference is largest + + # Find the location of maximum difference max_idx = abs_diff.argmax().item() flat_idx = max_idx if len(abs_diff.shape) > 1: @@ -197,6 +228,30 @@ def compute_loss(params, buffers, sample, target): f" Functional gradient value: {ft_per_sample_grad[tuple(indices)].item()}" ) + # Compute float64 baseline for the sample with the largest difference + print("\nComputing float64 baseline for comparison...") + try: + fp64_grads = compute_fp64_baseline(data, targets, indices) + fp64_value = fp64_grads[i][ + tuple(indices[1:]) + ].item() # Skip batch dimension + print(f" Float64 baseline value: {fp64_value}") + + # Compare both methods against float64 baseline + manual_diff = abs(per_sample_grad[tuple(indices)].item() - fp64_value) + functional_diff = abs( + ft_per_sample_grad[tuple(indices)].item() - fp64_value + ) + print(f" Manual method vs float64 difference: {manual_diff}") + print(f" Functional method vs float64 difference: {functional_diff}") + + if manual_diff < functional_diff: + print(" Manual method is closer to float64 baseline") + else: + print(" Functional method is closer to float64 baseline") + except Exception as e: + print(f" Error computing float64 baseline: {e}") + # Keep the original assertion assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) From 2c12321828dbcf650af69b28dd7425a3a18445e4 Mon Sep 17 00:00:00 2001 From: sekyondaMeta <127536312+sekyondaMeta@users.noreply.github.com> Date: Mon, 21 Jul 2025 15:34:22 -0400 Subject: [PATCH 13/13] Update per_sample_grads.py --- intermediate_source/per_sample_grads.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/intermediate_source/per_sample_grads.py b/intermediate_source/per_sample_grads.py index 23ce775faec..f2c9a7d7129 100644 --- a/intermediate_source/per_sample_grads.py +++ b/intermediate_source/per_sample_grads.py @@ -174,12 +174,10 @@ def compute_grad_fp64(sample, target): sample_fp64 = sample.to(torch.float64) target_fp64 = target - # Create a float64 version of the model - model_fp64 = SimpleCNN().to(device=device) - # Copy parameters from original model to float64 model - with torch.no_grad(): - for param_fp32, param_fp64 in zip(model.parameters(), model_fp64.parameters()): - param_fp64.copy_(param_fp32.to(torch.float64)) + # Create a float64 version of the model and explicitly convert it to float64 + model_fp64 = SimpleCNN().to(device=device).to(torch.float64) + + # No need to manually copy parameters as the model is already in float64 sample_fp64 = sample_fp64.unsqueeze(0) # prepend batch dimension target_fp64 = target_fp64.unsqueeze(0) @@ -254,7 +252,6 @@ def compute_fp64_baseline(data, targets, indices): # Keep the original assertion assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) - ###################################################################### # A quick note: there are limitations around what types of functions can be # transformed by ``vmap``. The best functions to transform are ones that are pure