Closed by gjhuizing.
Hi Geert!
Do you see the same phenomenon as the size increases? With point clouds of size 15, I would not expect gains on GPU (quite the contrary, in fact).
http://marcocuturi.net
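One way to answer this is to time a jitted Sinkhorn cost at several sizes. The timing should exclude compilation and should wait for each result with `block_until_ready()`, because JAX dispatches work asynchronously. The sketch below is a minimal illustration under stated assumptions. `sinkhorn_cost` is a hypothetical pure-JAX log-domain Sinkhorn used as a stand-in so the snippet runs without ott-jax; it is not ott's implementation. `time_jitted` is also a hypothetical helper. The sizes, `epsilon`, and iteration count are arbitrary choices.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_cost(x, y, epsilon=0.05, num_iters=200):
    """Entropic OT cost between uniform point clouds (hypothetical stand-in, not ott-jax)."""
    n, m = x.shape[0], y.shape[0]
    log_a = jnp.full(n, -jnp.log(n))  # log of uniform weights 1/n
    log_b = jnp.full(m, -jnp.log(m))  # log of uniform weights 1/m
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean, (n, m)

    def body(_, fg):
        f, g = fg
        # Log-domain Sinkhorn updates of the dual potentials f and g.
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(m)))
    return jnp.mean(f) + jnp.mean(g)  # dual objective <f, a> + <g, b> with uniform weights


def time_jitted(fn, *args, repeats=5):
    """Mean seconds per call, excluding compilation (hypothetical helper)."""
    fn(*args).block_until_ready()  # warm-up call triggers JIT compilation for these shapes
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for each result so asynchronous dispatch does not hide device time.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


cost_fn = jax.jit(sinkhorn_cost)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
platform = jax.devices()[0].platform
timings = {}
for n in (15, 100, 300):
    x = jax.random.normal(key1, (n, 2))
    y = jax.random.normal(key2, (n, 2))
    timings[n] = time_jitted(cost_fn, x, y)
    print(f"{platform}: n={n:4d} -> {timings[n] * 1e3:.3f} ms per call")
```

Running it once on a CPU runtime and once on a GPU runtime gives per-call times at each size to compare; the sketch itself makes no claim about where a crossover happens.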
On Fri, Mar 17, 2023 at 8:46 PM Geert-Jan Huizing @.***> wrote:
Hi,
I've been using a Sinkhorn loss in my code, and it was quite fast on my laptop's CPU. But when moving to a GPU cluster, my code suddenly ran more than 10 times slower. This is not the case when using some other loss.
Here is a minimal example showcasing this behavior, which I ran on Colab with CPU https://colab.research.google.com/drive/1ipUKAiIlnOJPijAv0OVmZjafppmTDYWN?usp=sharing and GPU https://colab.research.google.com/drive/1vR1MV-o6BEwAqHS8-SahycIRR7nySqEX?usp=sharing. The idea is simply to move a source point cloud `x2` to a target point cloud `x1` by minimizing an OT loss.
```python
!pip install ott-jax

import jax
from ott.solvers.linear import sinkhorn
from ott.geometry import pointcloud
import matplotlib.pyplot as plt
import optax

# Confirms that we're on CPU or GPU
print(jax.devices())

# Generate a target and a source point cloud
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
x1 = jax.random.normal(key1, (15, 2))
x2 = jax.random.normal(key2, (15, 2))

# Define a loss function
def loss(y):
    geom_xy = pointcloud.PointCloud(x1, y)
    ot_cost = sinkhorn.solve(geom_xy).reg_ot_cost
    geom_yy = pointcloud.PointCloud(y, y)
    ot_cost -= 0.5 * sinkhorn.solve(geom_yy).reg_ot_cost
    return ot_cost

optimizer = optax.adam(1e-2)

@jax.jit
def step(y, opt_state):
    """A JITted step function"""
    loss_value, grads = jax.value_and_grad(loss)(y)
    updates, opt_state = optimizer.update(grads, opt_state, y)
    y = optax.apply_updates(y, updates)
    return y, opt_state, loss_value

# The optimization loop
y = x2
opt_state = optimizer.init(x2)
for i in range(1_000):
    y, opt_state, loss_value = step(y, opt_state)
```
As you can see in the Colabs, a `%%timeit` on the optimization loop (after JIT compilation) gives:

- On CPU: 369 ms ± 6.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
- On GPU: 8.64 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I'm new to JAX, so maybe I'm using Optax or OTT wrong? Thanks a lot!
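The timings above work out to roughly 0.37 ms per step on CPU (369 ms / 1000) and 8.6 ms per step on GPU (8.64 s / 1000). With only 15 points each step does very little arithmetic, so the GPU time is plausibly dominated by fixed per-step overhead rather than computation. One common way to cut part of that overhead is to run the whole loop as a single compiled program with `jax.lax.scan`, which removes the per-step Python dispatch; kernel launches inside the compiled loop remain, so small problems may still favor the CPU. The sketch below is a minimal illustration under stated assumptions. `sinkhorn_cost` is a hypothetical pure-JAX stand-in, not ott-jax. Plain gradient descent replaces Adam to keep it dependency-free. The loss keeps the structure `OT(x1, y) - 0.5 * OT(y, y)` from the issue.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_cost(x, y, epsilon=0.05, num_iters=100):
    """Entropic OT cost between uniform point clouds (hypothetical stand-in, not ott-jax)."""
    n, m = x.shape[0], y.shape[0]
    log_a = jnp.full(n, -jnp.log(n))
    log_b = jnp.full(m, -jnp.log(m))
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

    def body(_, fg):
        f, g = fg
        # Log-domain Sinkhorn updates of the dual potentials.
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(m)))
    return jnp.mean(f) + jnp.mean(g)  # <f, a> + <g, b> with uniform weights


def loss(y, x1):
    # Same structure as the loss in the issue: OT(x1, y) - 0.5 * OT(y, y).
    return sinkhorn_cost(x1, y) - 0.5 * sinkhorn_cost(y, y)


@partial(jax.jit, static_argnames=("num_steps",))
def optimize(y0, x1, learning_rate=1e-2, num_steps=1_000):
    """Runs every optimization step inside one compiled jax.lax.scan."""

    def step(y, _):
        loss_value, grad = jax.value_and_grad(loss)(y, x1)
        # Plain gradient descent, swapped in for Adam to avoid the optax dependency.
        return y - learning_rate * grad, loss_value

    # Returns (final y, stacked per-step losses).
    return jax.lax.scan(step, y0, None, length=num_steps)


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jax.random.normal(key1, (15, 2))
x2 = jax.random.normal(key2, (15, 2))
y_final, losses = optimize(x2, x1)
y_final.block_until_ready()  # wait for the device before reading or timing results
print(float(losses[0]), float(losses[-1]))
```

The same idea should carry over to the Optax version by putting the optimizer state in the `scan` carry alongside `y`, though that variant is not shown here.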
View it on GitHub: https://github.com/ott-jax/ott/issues/338
Hi Marco, you're right: increasing to 100 points tips it in favor of the GPU! This is actually also the case for other losses; they were just too quick for me to spot the difference.
Closing this, thanks a lot!