ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

Sinkhorn slower on GPU than CPU #338

Closed gjhuizing closed 1 year ago

gjhuizing commented 1 year ago

Hi,

I've been using a Sinkhorn loss in my code, and it was quite fast on my laptop's CPU. But after moving to a GPU cluster, my code suddenly ran more than 10 times slower. This is not the case with other losses I've tried.

Here is a minimal example showcasing this behavior, which I ran on Colab with CPU (https://colab.research.google.com/drive/1ipUKAiIlnOJPijAv0OVmZjafppmTDYWN?usp=sharing) and GPU (https://colab.research.google.com/drive/1vR1MV-o6BEwAqHS8-SahycIRR7nySqEX?usp=sharing). The idea is just to move a source point cloud x2 to a target point cloud x1 by minimizing an OT loss.
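For reference, the loss below is loss(y) = OT_eps(x1, y) - 0.5 * OT_eps(y, y), i.e. the entropic OT cost debiased by half the self-cost of y; the remaining 0.5 * OT_eps(x1, x1) term of the full Sinkhorn divergence does not depend on y and is therefore dropped.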

!pip install ott-jax
import jax
from ott.solvers.linear import sinkhorn
from ott.geometry import pointcloud
import matplotlib.pyplot as plt
import optax

# Confirms that we're on CPU or GPU
print(jax.devices())

# Generate a target and a source point cloud
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
x1 = jax.random.normal(key1, (15, 2))
x2 = jax.random.normal(key2, (15, 2))

# Define a loss function
def loss(y):
  geom_xy = pointcloud.PointCloud(x1, y)
  ot_cost = sinkhorn.solve(geom_xy).reg_ot_cost

  geom_yy = pointcloud.PointCloud(y, y)
  ot_cost -= 0.5 * sinkhorn.solve(geom_yy).reg_ot_cost

  return ot_cost

optimizer = optax.adam(1e-2)

@jax.jit
def step(y, opt_state):
  """A JITted step function"""
  loss_value, grads = jax.value_and_grad(loss)(y)
  updates, opt_state = optimizer.update(grads, opt_state, y)
  y = optax.apply_updates(y, updates)
  return y, opt_state, loss_value

# The optimization loop
y = x2
opt_state = optimizer.init(x2)
for i in range(1_000):
  y, opt_state, loss_value = step(y, opt_state)

As you can see in the Colabs, a %%timeit on the optimization loop (after JIT compilation) gives:

On CPU:   369 ms ± 6.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
On GPU:   8.64 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
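
As an aside on methodology: JAX dispatches device work asynchronously, so the timing is more reliable if the loop blocks on its final result. A minimal variant (same setup as above, with only a trailing synchronization added; block_until_ready is standard JAX API):

# Same optimization loop as above, with an explicit synchronization so the
# measured time includes any work still pending on the device.
y = x2
opt_state = optimizer.init(x2)
for i in range(1_000):
  y, opt_state, loss_value = step(y, opt_state)
y.block_until_ready()  # wait for the last step to actually finish

This should not change the overall picture here, but it makes the CPU and GPU numbers directly comparable.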

I'm new to JAX, so maybe I'm using Optax or OTT wrong? Thanks a lot!

marcocuturi commented 1 year ago

Hi Geert!

Do you see the same phenomenon as size increases? With point clouds of size 15 I would not expect gains on GPU (quite the contrary, in fact).
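
One way to check where the crossover happens is to time a single jitted solve for increasing point-cloud sizes on each backend. This is only a sketch: it assumes both a "cpu" and a "gpu" backend are visible to jax.devices, and time_solve is just a throwaway helper name.

import time

import jax
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn

def time_solve(n, device, key=jax.random.PRNGKey(0)):
  """Time one jitted Sinkhorn solve between two n-point clouds on `device`."""
  key1, key2 = jax.random.split(key)
  x = jax.device_put(jax.random.normal(key1, (n, 2)), device)
  y = jax.device_put(jax.random.normal(key2, (n, 2)), device)
  solve = jax.jit(lambda a, b: sinkhorn.solve(pointcloud.PointCloud(a, b)).reg_ot_cost)
  solve(x, y).block_until_ready()  # warm-up: triggers compilation
  start = time.perf_counter()
  solve(x, y).block_until_ready()
  return time.perf_counter() - start

for n in (15, 100, 1_000, 10_000):
  cpu_t = time_solve(n, jax.devices("cpu")[0])
  gpu_t = time_solve(n, jax.devices("gpu")[0])
  print(f"n={n:>6}: cpu {cpu_t:.4f}s  gpu {gpu_t:.4f}s")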


gjhuizing commented 1 year ago

Hi Marco, you're right: increasing to 100 points tips it in favor of the GPU! And actually this holds for other losses too; the runs were just too quick to spot the difference.

Closing this, thanks a lot!