ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0
521 stars, 80 forks

Increased GPU memory usage when using a cost_fn different from costs.SqEuclidean() #504

Closed: felix0097 closed this issue 3 weeks ago

felix0097 commented 7 months ago

Hi,

I have a question regarding the memory usage when using a cost_fn different from the default costs.SqEuclidean().

I have a large dataset (~240,000 data points in x and ~400,000 data points in y). If I use pointcloud.PointCloud with the default cost_fn and batch_size=512, everything works fine. However, if I use a different cost_fn, e.g. the Cosine cost function, I run out of GPU memory, even if I further reduce batch_size. The batch_size argument no longer seems to have any effect.

This works:

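```python
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

geom = pointcloud.PointCloud(
    x, y, batch_size=512  # x: ~240k points, y: ~400k points
)

ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
ot = solver(ot_prob)
```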

This runs out of GPU memory:

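```python
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

geom = pointcloud.PointCloud(
    x, y,
    cost_fn=costs.Cosine(),  # the only change from the working example
    batch_size=512
)

ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
ot = solver(ot_prob)
```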

I'm using ott-jax==0.4.5

Thanks for your help! Felix

michalk8 commented 7 months ago

Strange: cosine should not have higher memory requirements when batch_size is set. In the code above, did you try jitting the solver, as in solver = jax.jit(sinkhorn.Sinkhorn())? If not, jitting could enable some memory optimizations.

michalk8 commented 7 months ago

As an alternative, there's a private method called _cosine_to_sqeucl which converts the cosine cost (defined as 1 - cosine_sim(x, y)) to the SqEuclidean cost.
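The identity behind such a conversion: for unit-normalized vectors u and v, ||u - v||^2 = 2 - 2 cos(u, v). The NumPy sketch below is not the private _cosine_to_sqeucl method (its signature isn't shown here); the helper names are hypothetical. It rescales normalized points by 1/sqrt(2) so that the squared Euclidean cost equals 1 - cos exactly.

```python
import numpy as np

def cosine_cost(x, y):
    """Dense pairwise cosine cost 1 - cos(x_i, y_j) (small inputs only)."""
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    yn = y / np.linalg.norm(y, axis=1, keepdims=True)
    return 1.0 - xn @ yn.T

def to_sqeucl_points(z):
    """Hypothetical helper: unit-normalize rows, then scale by 1/sqrt(2).

    For unit vectors ||u - v||^2 = 2 - 2 cos(u, v), so after scaling by
    1/sqrt(2) the squared distance becomes 1 - cos(u, v).
    """
    return z / np.linalg.norm(z, axis=1, keepdims=True) / np.sqrt(2.0)

def sqeucl_cost(x, y):
    """Dense pairwise squared Euclidean distance ||x_i - y_j||^2."""
    return (
        np.sum(x**2, axis=1)[:, None]
        + np.sum(y**2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(4, 3))
print(np.allclose(cosine_cost(x, y),
                  sqeucl_cost(to_sqeucl_points(x), to_sqeucl_points(y))))
```

The transformed arrays could then be passed to pointcloud.PointCloud with its default squared Euclidean cost, the configuration that worked with batch_size=512 above. OTT's Cosine divides by x_norm * y_norm plus a small ridge term, so its values may differ very slightly from this exact identity.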

marcocuturi commented 7 months ago

This is a detail, but if the jitting suggested above by Michal still fails, do you see different behavior when you swap the argument positions of x and y?

The parallelization we currently implement is only line-wise, with lines of size 400k in your setup. So in practice, the matrices stored at each iteration are 400k x 512, i.e. roughly as many entries as 40k x 5k or 20k x 10k, which can start to get heavy depending on your GPU and on whether you are using float64.
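Back-of-the-envelope arithmetic for the per-iteration block described above, assuming one dense 400k x 512 block is held at a time and ignoring any other overhead:

```python
rows = 400_000      # line size in this setup
batch_size = 512

def block_gib(n_rows, n_cols, bytes_per_entry):
    """Memory of a dense n_rows x n_cols array, in GiB."""
    return n_rows * n_cols * bytes_per_entry / 1024**3

f32 = block_gib(rows, batch_size, 4)  # float32: 4 bytes per entry
f64 = block_gib(rows, batch_size, 8)  # float64: 8 bytes per entry
print(f"float32: {f32:.2f} GiB, float64: {f64:.2f} GiB")
# prints "float32: 0.76 GiB, float64: 1.53 GiB"
```

Each such block stays well under a GiB in float32, which is far from the allocation reported later in the thread.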

felix0097 commented 7 months ago

I tried jitting the solver, but it didn't work: my notebook kernel just dies. I think the issue might be that the code ignores the batch_size argument altogether. I get the following error message:

```pytb
2024-03-21 06:49:02.013401: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 377.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[16], line 3
      1 ot_prob = linear_problem.LinearProblem(geom)
      2 solver = sinkhorn.Sinkhorn()
----> 3 ot = solver(ot_prob)

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:864, in Sinkhorn.__call__(self, ot_prob, init, rng)
    860 initializer = self.create_initializer()
    861 init_dual_a, init_dual_b = initializer(
    862     ot_prob, *init, lse_mode=self.lse_mode, rng=rng
    863 )
--> 864 return run(ot_prob, self, (init_dual_a, init_dual_b))

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1141, in run(ot_prob, solver, init)
   1139 """Run loop of the solver, outputting a state upgraded to an output."""
   1140 iter_fun = _iterations_implicit if solver.implicit_diff else iterations
-> 1141 out = iter_fun(ot_prob, solver, init)
   1142 # Be careful here, the geom and the cost are injected at the end, where it
   1143 # does not interfere with the implicit differentiation.
   1144 out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)

    [... skipping hidden 5 frame]

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1178, in iterations(ot_prob, solver, init)
   1176 const = ot_prob, solver
   1177 state = solver.init_state(ot_prob, init)
-> 1178 state = fix_point(
   1179     cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
   1180     solver.inner_iterations, const, state
   1181 )
   1182 return solver.output_from_state(ot_prob, state)

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/math/fixed_point_loop.py:92, in fixpoint_iter(cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, constants, state)
     86   (_, state), _ = jax.lax.scan(
     87       lambda carry, x: unrolled_body_fn(carry), (0, state),
     88       None,
     89       length=max_iterations // inner_iterations
     90   )
     91 else:
---> 92   _, state = jax.lax.while_loop(max_cond_fn, unrolled_body_fn, (0, state))
     93 return state

    [... skipping hidden 21 frame]

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/jax/_src/compiler.py:237, in backend_compile(backend, module, options, host_callbacks)
    232   return backend.compile(built_c, compile_options=options,
    233                          host_callbacks=host_callbacks)
    234 # Some backends don't have `host_callbacks` option yet
    235 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    236 # to take in `host_callbacks`
--> 237 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes.
```

The code tries to allocate 377.03 GiB of GPU memory, which corresponds almost exactly to the size of the full cost matrix in float32: x.shape[0] * y.shape[0] * 4 / 1024**3 = 239696 * 422220 * 4 / 1024**3 = 377.02 GiB.

Also, the 377.03 GiB allocation is independent of batch_size: when I reduce the batch size, the code still tries to allocate 377.03 GiB.
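The arithmetic above can be checked directly. The inputs are the dataset sizes and the byte count from the XLA error message:

```python
n_x, n_y = 239_696, 422_220           # x.shape[0], y.shape[0]
bytes_f32 = n_x * n_y * 4             # full dense cost matrix in float32
print(bytes_f32)                      # 404817780480
print(round(bytes_f32 / 1024**3, 2))  # 377.02 (GiB)

requested = 404_834_557_696           # bytes in the RESOURCE_EXHAUSTED error
print(round(requested / 1024**3, 2))  # 377.03 (GiB)
print(requested - bytes_f32)          # 16777216, i.e. 2**24 bytes (16 MiB) more
```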

michalk8 commented 7 months ago

I looked at test_sinkhorn_online_memory_jit and modified it to use cost_fn=costs.Cosine(). Memory usage didn't increase much (7.6 MiB -> 8.7 MiB) on CPU, so I'm not sure exactly what's going on on your system. @felix0097 What's your JAX/OTT-JAX version?

michalk8 commented 7 months ago

Also, could you please check the generated jaxpr using jax.make_jaxpr to see whether the full cost matrix is indeed being materialized there?
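As a rough illustration of what to look for, here is a toy example (not OTT's code) of two functions that compute the same mean squared Euclidean cost. The first builds the whole (n, m) matrix; the second processes one row of x at a time with jax.lax.map. Searching the printed jaxpr for the full shape string, here f32[7,11], is the same kind of check as looking for a pattern like f32[422220,239696].

```python
import jax
import jax.numpy as jnp

def mean_cost_dense(x, y):
    # Materializes the full (n, m) squared-distance matrix before averaging.
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return cost.mean()

def mean_cost_rowwise(x, y):
    # lax.map scans over rows of x, so only (m,)-sized rows are built per step.
    row_means = jax.lax.map(lambda xi: jnp.sum((xi - y) ** 2, axis=-1).mean(), x)
    return row_means.mean()

x = jnp.ones((7, 3))   # n = 7 points
y = jnp.ones((11, 3))  # m = 11 points

for fn in (mean_cost_dense, mean_cost_rowwise):
    jaxpr_text = str(jax.make_jaxpr(fn)(x, y))
    # True only if some intermediate has the full (n, m) float32 shape.
    print(fn.__name__, "f32[7,11]" in jaxpr_text)
```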

felix0097 commented 7 months ago

I'm using jax==0.4.25 and ott-jax==0.4.5 @michalk8.

I've attached the output of jax.make_jaxpr below. I'm not really familiar with how to interpret the results, but the pattern f32[422220,239696] shows up quite a few times, e.g. from line 188 onwards and from line 657 onwards.

make_jaxpr.txt

felix0097 commented 7 months ago

I also tried the code above on a different system and hit the same problem. I used the JAX container (version 23.10-py3) from NGC with jax==0.4.17.dev20231020 and ott-jax==0.4.4.

michalk8 commented 7 months ago

OK, I will take a closer look at this. For now, I recommend converting the cosine cost to SqEuclidean, as mentioned above.

michalk8 commented 7 months ago

It seems this is an issue with epsilon=None: when no epsilon is passed, it is computed from the mean of the cost matrix, and that computation runs out of memory:

```pytb
Traceback (most recent call last):
  File "/mnt/task_runtime/test.py", line 13, in <module>
    out = solve_fn(geom, min_iterations=1, max_iterations=1)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/_solve.py", line 60, in solve
    return solver(prob)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 864, in __call__
    return run(ot_prob, self, (init_dual_a, init_dual_b))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 1144, in run
    out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 342, in set_cost
    return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 256, in compute_kl_reg_cost
    fa = ot_prob.geom.potential_from_scaling(ot_prob.a)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 519, in potential_from_scaling
    return self.epsilon * jnp.log(scaling)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 162, in epsilon
    return self._epsilon.target
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 150, in _epsilon
    scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 130, in mean_cost_matrix
    tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 376, in apply_cost
    return self._apply_cost(arr, axis, fn=fn)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 393, in _apply_cost
    return app(
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 784, in _apply_cost_xy
    c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 760, in _cost
    cost = norm_x + norm_y + one_line_pairwise(x, y)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/costs.py", line 317, in pairwise
    cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes.
```

Passing epsilon=<some float> fixes this, but it is not a good solution overall: we should be pre-computing these statistics beforehand. However, this will require a lot of work and will break the API in some places. In general I support this change and will think about how to do it most efficiently.