felix0097 closed this issue 3 weeks ago
Strange, cosine should not have any more memory requirements when using `batch_size`? In the code above, did you try jitting the solver, i.e. `solver = jax.jit(sinkhorn.Sinkhorn())`? If not, this could lead to some possible memory optimizations.
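A minimal sketch of that suggestion, with small placeholder arrays standing in for the real data:

```python
import jax
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Placeholder data standing in for the real x and y.
x = jax.random.normal(jax.random.PRNGKey(0), (1_000, 512))
y = jax.random.normal(jax.random.PRNGKey(1), (2_000, 512))
geom = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine(), batch_size=512)

ot_prob = linear_problem.LinearProblem(geom)
# Jit the solver instance; Sinkhorn objects are callable, so the whole
# solve is compiled as one XLA computation.
solver = jax.jit(sinkhorn.Sinkhorn())
out = solver(ot_prob)
```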
As an alternative, there's a private method called `_cosine_to_sqeucl` which converts the cosine cost (defined as `1 - cosine_sim(x, y)`) to the `SqEucl` cost.
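For reference, the identity behind that conversion: for unit-norm vectors, `||x - y||^2 = 2 - 2 * cosine_sim(x, y)`. A hand-rolled sketch of the same idea (the helper name is illustrative, not the library API):

```python
import jax.numpy as jnp
from ott.geometry import pointcloud

def cosine_as_sqeucl_geom(x, y, batch_size=512):
    # For unit-norm vectors, 1 - cosine_sim(x, y) == 0.5 * ||x - y||^2,
    # so normalizing the points lets the default SqEuclidean cost be used.
    x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
    y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
    # scale_cost=2.0 divides the cost by 2 (assuming the recent ott-jax
    # convention cost_matrix /= scale_cost), recovering the cosine cost.
    return pointcloud.PointCloud(x, y, batch_size=batch_size, scale_cost=2.0)
```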
This is a detail, but if the jitting suggested above by Michal still fails, do you see a different behavior by inverting the arg positions of `x` and `y`?
The current parallelization we have implemented is only line-wise, the size of lines being ~400k with your definition. So in practice, the matrices stored at each iteration are 400k x 512 ~ 40k x 5k ~ 20k x 10k entries, which can start being a bit heavy depending on your GPU and on whether you are using float64.
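As a back-of-the-envelope check of those block sizes (a sketch, assuming float32):

```python
# One line-wise block: all rows of the ~400k-point side against one batch
# of 512 columns, i.e. roughly len(y) * batch_size entries.
n, batch_size, bytes_per_float = 422_220, 512, 4  # float32; use 8 for float64
print(f"{n * batch_size * bytes_per_float / 1024**3:.2f} GiB per block")
# -> ~0.81 GiB in float32, ~1.61 GiB in float64
```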
I tried jitting the solver, but this didn't work: my notebook kernel just dies then. I think the issue might be that the code ignores the `batch_size` argument altogether. I get the following error message:
2024-03-21 06:49:02.013401: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 377.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[16], line 3
1 ot_prob = linear_problem.LinearProblem(geom)
2 solver = sinkhorn.Sinkhorn()
----> 3 ot = solver(ot_prob)
File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:864, in Sinkhorn.__call__(self, ot_prob, init, rng)
860 initializer = self.create_initializer()
861 init_dual_a, init_dual_b = initializer(
862 ot_prob, *init, lse_mode=self.lse_mode, rng=rng
863 )
--> 864 return run(ot_prob, self, (init_dual_a, init_dual_b))
File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1141, in run(ot_prob, solver, init)
1139 """Run loop of the solver, outputting a state upgraded to an output."""
1140 iter_fun = _iterations_implicit if solver.implicit_diff else iterations
-> 1141 out = iter_fun(ot_prob, solver, init)
1142 # Be careful here, the geom and the cost are injected at the end, where it
1143 # does not interfere with the implicit differentiation.
1144 out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
[... skipping hidden 5 frame]
File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1178, in iterations(ot_prob, solver, init)
1176 const = ot_prob, solver
1177 state = solver.init_state(ot_prob, init)
-> 1178 state = fix_point(
1179 cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
1180 solver.inner_iterations, const, state
1181 )
1182 return solver.output_from_state(ot_prob, state)
File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/math/fixed_point_loop.py:92, in fixpoint_iter(cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, constants, state)
86 (_, state), _ = jax.lax.scan(
87 lambda carry, x: unrolled_body_fn(carry), (0, state),
88 None,
89 length=max_iterations // inner_iterations
90 )
91 else:
---> 92 _, state = jax.lax.while_loop(max_cond_fn, unrolled_body_fn, (0, state))
93 return state
[... skipping hidden 21 frame]
File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/jax/_src/compiler.py:237, in backend_compile(backend, module, options, host_callbacks)
232 return backend.compile(built_c, compile_options=options,
233 host_callbacks=host_callbacks)
234 # Some backends don't have `host_callbacks` option yet
235 # TODO(sharadmv): remove this fallback when all backends allow `compile`
236 # to take in `host_callbacks`
--> 237 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes.
The code tries to allocate 377.03GiB of GPU memory, which corresponds almost perfectly to the size of the full cost matrix in float32: `x.shape[0] * y.shape[0] * 4 / 1024**3 = 239696 * 422220 * 4 / 1024**3 = 377.02GiB`.
Also, the 377.03GiB are independent of the `batch_size` I use: I can reduce the batch size and the code still tries to allocate the same 377.03GiB of memory.
I looked at the `test_sinkhorn_online_memory_jit` test and modified it with `cost_fn=costs.Cosine()`; memory didn't increase much (7.6MiB -> 8.7MiB) on CPU, so I'm not sure exactly what's going on above on your system.
@felix0097 What's your JAX/OTT-JAX version? Also, could you please check the generated XLA code using `jax.make_jaxpr` to see whether the cost matrix is indeed being materialized there?
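For instance, something along these lines (a sketch with small placeholder arrays; at full size the suspicious intermediates would show up as `f32[422220,239696]`):

```python
import jax
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

x = jax.random.normal(jax.random.PRNGKey(0), (1_000, 512))
y = jax.random.normal(jax.random.PRNGKey(1), (2_000, 512))

geom = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine(), batch_size=512)
prob = linear_problem.LinearProblem(geom)

# Inspect the traced computation; an intermediate shaped f32[2000,1000]
# (i.e. len(y) x len(x)) would mean the full cost matrix is materialized.
print(jax.make_jaxpr(sinkhorn.Sinkhorn())(prob))
```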
I'm using `jax==0.4.25` and `ott-jax==0.4.5`, @michalk8.
I've attached the output of the `jax.make_jaxpr` function below. I'm not really familiar with how to interpret the results, but the pattern `f32[422220,239696]` shows up quite a few times, e.g. from line 188 onwards and from line 657 onwards.
I tried the code above on a different system as well and saw the same problem. I used the JAX container (version `23.10-py3`) from NGC with `jax==0.4.17.dev20231020` and `ott-jax==0.4.4`.
Ok, will take a closer look at this. For now, I recommend converting the cosine cost to sqeucl as mentioned above.
Seems like this is an issue with `epsilon=None` and it being computed from the cost statistics. Passing `epsilon=<some float>` fixes this, but this is not a good solution overall; we should be pre-computing the statistics beforehand. However, this will require a lot of work and break the API in some places, though in general I support this and will think of a few ways to do it most efficiently.
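The workaround, as a sketch (the `epsilon` value here is an arbitrary placeholder to tune; `x` and `y` are the user's point clouds):

```python
from ott.geometry import costs, pointcloud

geom = pointcloud.PointCloud(
    x, y,
    cost_fn=costs.Cosine(),
    batch_size=512,
    # An explicit epsilon skips the data-dependent default, whose
    # computation currently materializes the full cost matrix.
    epsilon=1e-2,  # placeholder value
)
```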
Hi,

I have a question regarding memory usage when using a `cost_fn` different from the default `costs.SqEuclidean()`. I have a large dataset (~240,000 data points in `x` and ~400,000 data points in `y`). If I use `pointcloud.PointCloud` with the default `cost_fn` and `batch_size=512`, everything works fine. However, if I use a different `cost_fn`, e.g. the `Cosine` cost function, I run out of GPU memory, even if I further reduce the batch size (somehow the `batch_size` argument does not really seem to have an effect anymore).

This works:
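(The original snippet was not preserved in this thread; below is a minimal reconstruction of the working setup, with `x` and `y` standing in for the arrays described above.)

```python
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Default cost_fn (costs.SqEuclidean) with online (batched) evaluation.
geom = pointcloud.PointCloud(x, y, batch_size=512)
out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
```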
This runs out of GPU memory:
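(Again a reconstruction; identical except for the cost function.)

```python
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Same setup, but with the cosine cost.
geom = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine(), batch_size=512)
out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
```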
I'm using `ott-jax==0.4.5`.

Thanks for your help!
Felix