theislab / moscot

Multi-omic single-cell optimal transport tools
https://moscot-tools.org
BSD 3-Clause "New" or "Revised" License

`.impute` consumes too much memory #729

Open · Marius1311 opened this issue 2 weeks ago

Marius1311 commented 2 weeks ago

I'm trying to call `problem.impute()` on a solved (linear) spatial mapping problem with n_source=17806 (spatial data), n_target=13298 (single-cell data) and n_genes=2039. It is a plain full-rank Sinkhorn problem, solved with batch_size=None.

Under the hood, this evaluates:

```python
predictions = [val.to(device=device).pull(gexp_sc, scale_by_marginals=True) for val in self.solutions.values()]
```

The pull amounts to a matrix multiplication, prediction = P @ X, for a transport matrix P of shape 17806 x 13298 and a single-cell GEX matrix X of shape 13298 x 2039. The memory bottleneck should therefore be P, which is stored as float32 and should take up around 903 MiB. However, the call to `impute` fails (see traceback below) because it requests 1.76 TiB of memory: it tries to create an array of shape f32[2039,17806,13298], which is not needed for this operation.
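A quick check of these numbers, assuming 4 bytes per float32 element and binary units (1 MiB = 2**20 bytes, 1 TiB = 2**40 bytes). The helper name `f32_nbytes` is only for illustration:

```python
import math

BYTES_PER_F32 = 4  # size of one float32 element


def f32_nbytes(shape):
    """Bytes needed to hold a dense float32 array of the given shape."""
    return math.prod(shape) * BYTES_PER_F32


n_source, n_target, n_genes = 17806, 13298, 2039

p_bytes = f32_nbytes((n_source, n_target))              # transport matrix P
x_bytes = f32_nbytes((n_target, n_genes))               # single-cell expression X
out_bytes = f32_nbytes((n_source, n_genes))             # prediction P @ X
buf_bytes = f32_nbytes((n_genes, n_source, n_target))   # 3-D buffer from the traceback

print(f"P:          {p_bytes / 2**20:.2f} MiB")
print(f"X:          {x_bytes / 2**20:.2f} MiB")
print(f"P @ X:      {out_bytes / 2**20:.2f} MiB")
print(f"3-D buffer: {buf_bytes / 2**40:.2f} TiB ({buf_bytes} bytes)")
```

The size of P matches the 903.26MiB of Buffer 2 in the traceback, and the 3-D buffer matches the 1931211837328 bytes XLA tried to allocate, while the actual inputs and output of the product together need only about 1.1 GiB.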

Passing a batch size does not help much: with batch_size=500, for example, it would still request an array of shape 2039 x 500 x 13298, which requires over 50 GB of memory. Batching also slows down solving the actual OT problem, which would not be necessary from a memory point of view.
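The same arithmetic for the batched case, assuming (as described above) that each batch materializes a float32 buffer of shape (n_genes, batch_size, n_target):

```python
import math

BYTES_PER_F32 = 4  # size of one float32 element

n_target, n_genes, batch_size = 13298, 2039, 500

# One (n_genes, batch_size, n_target) float32 buffer per batch.
batch_buf_bytes = math.prod((n_genes, batch_size, n_target)) * BYTES_PER_F32

print(f"{batch_buf_bytes / 1e9:.1f} GB ({batch_buf_bytes / 2**30:.1f} GiB) per batch")
```

This comes out at roughly 54.2 GB (50.5 GiB), consistent with the "over 50GB" figure; the buffer still scales with n_genes, so shrinking the cell batch only divides it by n_source / batch_size.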

I talked to @michalk8 about this and it's probably a vmap that creates an array of the wrong shape. For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.
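A minimal NumPy sketch of the gene-wise workaround. The helper `pull_in_gene_chunks` and `chunk_size` are hypothetical names, not part of moscot's API; in moscot the per-chunk product would be the solution's `pull` call, here a dense P stands in for it, and the `scale_by_marginals` normalization is not reproduced. If the suspected vmap issue persists inside each call, the intermediate would scale with `chunk_size` instead of n_genes, which is why this helps but remains inefficient:

```python
import numpy as np


def pull_in_gene_chunks(P, X, chunk_size=128):
    """Compute P @ X by processing `chunk_size` columns (genes) of X at a time.

    Hypothetical helper for illustration only.
    """
    n_genes = X.shape[1]
    chunks = [
        P @ X[:, start:start + chunk_size]  # (n_source, <= chunk_size) block of the result
        for start in range(0, n_genes, chunk_size)
    ]
    return np.concatenate(chunks, axis=1)   # (n_source, n_genes)


# Small random stand-ins for the real 17806 x 13298 P and 13298 x 2039 X.
rng = np.random.default_rng(0)
P = rng.random((60, 40), dtype=np.float32)
X = rng.random((40, 25), dtype=np.float32)

pred = pull_in_gene_chunks(P, X, chunk_size=7)
print(pred.shape)
```

Concatenating along the gene axis gives the same result as the full product, because each column of P @ X depends only on the matching column of X.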

If the transport matrix fits into CPU memory, the best current workaround is to materialize the transport matrix before calling `impute`:

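```python
import numpy as np

for value in lmp.problems.values():
    # Keep the object returned by `.to()` rather than assuming an in-place move.
    solution = value.solution.to(device="cpu")
    value.set_solution(np.array(solution.transport_matrix), overwrite=True)
```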

That prevents the memory issue.

Traceback:

```
2024-07-05 10:45:20.572529: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.76TiB (rounded to 1931211837440)requested by op 
2024-07-05 10:45:20.572824: W external/tsl/tsl/framework/bfc_allocator.cc:497] *****_______________________________________________________________________________________________
2024-07-05 10:45:20.572951: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1931211837328 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  929.12MiB
              constant allocation:         0B
        maybe_live_out allocation:    1.76TiB
     preallocated temp allocation:         0B
                 total allocation:    1.76TiB
              total fragmentation:         0B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 1.76TiB
        Operator: op_name="jit(_where)/jit(main)/select_n" source_file="/cluster/project/treutlein/USERS/mlange/github/moscot-fork/src/moscot/backends/ott/output.py" source_line=177
        XLA Label: fusion
        Shape: f32[2039,17806,13298]
        ==========================

    Buffer 2:
        Size: 903.26MiB
        Entry Parameter Subshape: f32[17806,13298]
        ==========================

    Buffer 3:
        Size: 25.86MiB
        Entry Parameter Subshape: pred[2039,1,13298]
        ==========================

    Buffer 4:
        Size: 4B
        Entry Parameter Subshape: f32[]
        ==========================
```
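The buffer list is consistent with a broadcasting select: the 929.12MiB parameter allocation is exactly Buffer 2 (the 17806 x 13298 float32 transport matrix, 903.26MiB) plus Buffer 3 (a boolean predicate of shape pred[2039,1,13298], 25.86MiB), and broadcasting those two shapes against each other yields the f32[2039,17806,13298] of Buffer 1. The NumPy sketch below reproduces only that shape arithmetic with tiny stand-in sizes (`jax.numpy.where` follows the same broadcasting rules). The random mask and the zero fill value are placeholders, and whether output.py line 177 does exactly this is an assumption:

```python
import numpy as np

# Tiny stand-ins for the traceback shapes:
# n_genes=2039 -> g, n_source=17806 -> n, n_target=13298 -> m.
g, n, m = 3, 5, 4
rng = np.random.default_rng(0)

P = rng.random((n, m), dtype=np.float32)  # like f32[17806,13298] (Buffer 2)
mask = rng.random((g, 1, m)) > 0.5        # like pred[2039,1,13298] (Buffer 3)

# Broadcasting (g, 1, m) against (n, m) gives (g, n, m):
# one full copy of P per gene, like f32[2039,17806,13298] (Buffer 1).
selected = np.where(mask, P, np.float32(0.0))
print(selected.shape)  # (3, 5, 4)

# The product the pull actually needs is only (n, g).
X = rng.random((m, g), dtype=np.float32)
print((P @ X).shape)   # (5, 3)
```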
giovp commented 2 weeks ago

Hi @Marius1311, yes, I have observed this several times as well and reported it privately to @michalk8.

> Note that passing a batch size does not help much - let's say I'm passing batch_size=500, then this would still request an array of shape 2039 x 500 x 13298, which still requires over 50GB of memory. Also, this slows down solving the actual OT problem, which would not be necessary from a memory point of view.
>
> I talked to @michalk8 about this and it's probably a vmap that creates an array of the wrong shape. For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.

And yes, I also think this is due to the vmap. In my experience it also happens for GW problems, and not only for imputation but also for e.g. cell transitions: basically anywhere the transport matrix is applied.

> For now, we could solve this by evaluating the pull batch-wise over small sets of genes. This is inefficient, but would solve the issue for now.

This would work, but it would require a considerable amount of work, since various mixin methods use that operation.

Marius1311 commented 2 weeks ago

Yes, I agree with you @giovp: batch-wise evaluation isn't really the way to go and can only be a temporary fix. For me personally, materializing the transport matrix before calling `.pull` is the best solution, as long as the matrix fits into memory.