
Cannot batch ot.emd2 via torch.vmap #532

Open
@oleg-kachan

Description


Describe the bug

Since my datapoints are empirical distributions, I want to use the Wasserstein distance as a loss function over a batch of shape (n_batch, n_points, dimension). The standard way to make a function accept batched input is torch.vmap, yet I get the error described below.

To Reproduce

import torch
import ot  # POT: Python Optimal Transport

def wasserstein2_loss(X, Y):
    # X: (n, d), Y: (m, d) -- two empirical distributions
    n, m = X.shape[0], Y.shape[0]
    a = torch.ones(n) / n  # uniform weights on the points of X
    b = torch.ones(m) / m  # uniform weights on the points of Y
    M = ot.dist(X, Y, metric="sqeuclidean")  # pairwise squared-Euclidean cost matrix
    return ot.emd2(a, b, M) ** 0.5  # exact Wasserstein-2 distance

wasserstein2_loss_batched = torch.vmap(wasserstein2_loss)
W2 = wasserstein2_loss_batched(X, Y)  # should be an array of shape `n_batch`

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 W2 = wasserstein2_loss_batched(X, Y)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:434, in vmap.<locals>.wrapped(*args, **kwargs)
    430     return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
    431                          args_spec, out_dims, randomness, **kwargs)
    433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
    435     func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    436 )

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:39, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     36 @functools.wraps(f)
     37 def fn(*args, **kwargs):
     38     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39         return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py:619, in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    617 try:
    618     batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619     batched_outputs = func(*batched_inputs, **kwargs)
    620     return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    621 finally:

Cell In[4], line 13, in wasserstein2_loss(X, Y)
     11 b = torch.ones(m) / m
     12 M = ot.dist(X, Y, metric="sqeuclidean")
---> 13 return wasserstein_distance(a, b, M) ** 0.5

File /usr/local/lib/python3.10/dist-packages/ot/lp/__init__.py:488, in emd2(a, b, M, processes, numItermax, log, return_matrix, center_dual, numThreads, check_marginals)
    485 nx = get_backend(M0, a0, b0)
    487 # convert to numpy
--> 488 M, a, b = nx.to_numpy(M, a, b)
    490 a = np.asarray(a, dtype=np.float64)
    491 b = np.asarray(b, dtype=np.float64)

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in Backend.to_numpy(self, *arrays)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:207, in <listcomp>(.0)
    205     return self._to_numpy(arrays[0])
    206 else:
--> 207     return [self._to_numpy(array) for array in arrays]

File /usr/local/lib/python3.10/dist-packages/ot/backend.py:1763, in TorchBackend._to_numpy(self, a)
   1761 if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
   1762     return np.array(a)
-> 1763 return a.cpu().detach().numpy()

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
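For context, this failure seems reproducible with any function that converts its input to NumPy inside vmap, which is what ot.emd2 does in its torch backend. A minimal sketch of my understanding, independent of POT (the function name is just for illustration):

```python
import torch

# Inside torch.vmap, inputs are BatchedTensor wrappers that have no
# real storage behind them, so any tensor -> NumPy conversion (which
# ot.emd2 performs in its backend) fails with a RuntimeError.
def needs_numpy(x):
    # hypothetical function that leaves the torch graph via .numpy()
    return torch.from_numpy(x.numpy() * 2.0)

raised = False
try:
    torch.vmap(needs_numpy)(torch.randn(3, 5))
except RuntimeError:
    raised = True  # same class of error as above
```

So the error is not specific to my loss function; it is triggered by the NumPy round-trip itself.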

Expected behavior

POT distance functions should be batchable via torch.vmap; the Sinkhorn distance code seems to have the same problem, since it also converts its tensor inputs to NumPy.
