elixir-explorer / explorer

Series (one-dimensional) and dataframes (two-dimensional) for fast and elegant data exploration in Elixir
https://hexdocs.pm/explorer
MIT License

`Nx.stack(..., axis: 1)` on a large dataframe can OOM #858

Closed: billylanchantin closed this issue 9 months ago

billylanchantin commented 9 months ago

I'm having an unexpected issue at work. I have a large dataframe:

```elixir
DF.shape(df) #=> {15_653_051, 122}
```

This OOMs:

```elixir
Nx.stack(df, axis: 1)
```

but this works:

```elixir
n_rows = DF.n_rows(df)
num_slices = 10
size_slice = div(n_rows, num_slices)

# 0..num_slices yields num_slices + 1 start indices, so the final slice
# picks up the remainder rows; empty slices are then dropped.
0..num_slices
|> Stream.map(&(&1 * size_slice))
|> Stream.map(fn start_index -> DF.slice(df, start_index, size_slice) end)
|> Stream.reject(&(DF.n_rows(&1) == 0))
|> Stream.map(&Nx.stack(&1, axis: 1))
|> Enum.reduce(&Nx.concatenate([&2, &1]))
```
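A toy-scale sketch of the same chunked approach, using a list of plain `Nx` column tensors in place of the dataframe (to keep the example dependency-light; the sizes and values are illustrative, not the original 15M-row data):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Two "columns" standing in for dataframe columns.
columns = [Nx.tensor([1, 2, 3, 4, 5]), Nx.tensor([10, 20, 30, 40, 50])]

n_rows = 5
num_slices = 2
size_slice = div(n_rows, num_slices)

chunked =
  0..num_slices
  |> Stream.map(&(&1 * size_slice))
  |> Stream.map(fn start ->
    # Clamp the slice length so the last chunk only takes the remainder rows.
    len = min(size_slice, n_rows - start)
    Enum.map(columns, &Nx.slice(&1, [start], [len]))
  end)
  |> Stream.reject(fn [col | _] -> Nx.size(col) == 0 end)
  |> Stream.map(&Nx.stack(&1, axis: 1))
  |> Enum.reduce(&Nx.concatenate([&2, &1]))

direct = Nx.stack(columns, axis: 1)

# The chunked result matches stacking everything at once.
true = Nx.to_flat_list(chunked) == Nx.to_flat_list(direct)
```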

It surprises me that the bare `Nx.stack` call consumes so much more memory than breaking the computation apart. Is this expected?

Apologies, I haven't had time to dig into the details yet. I'm also happy to close this and move it to the forums, Slack, etc. if preferred.

josevalim commented 9 months ago

That's definitely weird, it should be zero copy. Can you confirm that calling `Enum.map(df.names, fn name -> df[name] end)` does not cause a big impact on memory? And then what about `Enum.map(df.names, fn name -> Series.to_tensor(df[name]) end)`? Those should all be lightweight references.

billylanchantin commented 9 months ago

@josevalim Yeah, both are pretty much instant.

> That's definitely weird, it should be zero copy.

In this case I do expect some churn: Polars is a column store, and `Nx.stack` with `axis: 1` is effectively doing a transpose.
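The "effectively a transpose" point can be seen on a tiny example: stacking column tensors along axis 1 produces the same data as stacking them along axis 0 (the layout Polars already has) and then transposing. A minimal sketch with made-up values:

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Two equal-length "columns", standing in for Polars' columnar layout.
col_a = Nx.tensor([1, 2, 3])
col_b = Nx.tensor([10, 20, 30])

by_column = Nx.stack([col_a, col_b])           # shape {2, 3}: columns stay contiguous
by_row    = Nx.stack([col_a, col_b], axis: 1)  # shape {3, 2}: values must be interleaved

# axis: 1 stacking is the transpose of axis: 0 stacking.
true = Nx.to_flat_list(by_row) == Nx.to_flat_list(Nx.transpose(by_column))
```

So the `axis: 1` path cannot reuse the column buffers as-is; it has to materialize a row-major interleaving.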

josevalim commented 9 months ago

Which backend are you using, EXLA or BinaryBackend?

josevalim commented 9 months ago

Also, can you try `Nx.concatenate` instead of `Nx.stack`? I just want to rule something out quickly.

billylanchantin commented 9 months ago

BinaryBackend, unfortunately :(

My intention is to pass this to `EXGBoost.train`, but EXGBoost has some issues with non-binary backends.

I've tried using the EXLA backend, then serializing and deserializing, but that didn't work. (There were many steps to the process, so I might've messed something up.)

billylanchantin commented 9 months ago

Oh sorry, I missed your earlier message.

> Also, can you try `Nx.concatenate` instead of stack? I just want to rule something out quickly.

Will do. I'm already running this in the background to see whether the bottleneck is Nx or Explorer:

```elixir
transposed = Nx.stack(df) # worked pretty much instantly
Nx.transpose(transposed)  # currently working hard...
```

Will try your suggestion once it's done.

josevalim commented 9 months ago

In any case, I think we should make `Nx.stack` a default backend callback, so that when it is invoked for EXLA it does not perform several copies. I will open a separate issue.

billylanchantin commented 9 months ago

`Nx.transpose(transposed)` OOM'd.

`Nx.concatenate(df)` succeeded pretty quickly (though it has shape `{1_909_672_222}`, i.e. 15_653_051 × 122 flattened into one axis). Did you want me to pass any options to `Nx.concatenate/2`?

josevalim commented 9 months ago

I guess we can close this. We have optimized `Nx.concatenate` in Nx for the default backend, and there is an open issue to optimize `Nx.stack` too.
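One way `stack/2` can be reduced to the already-optimized `concatenate/2` is to add the new axis to each tensor first and then concatenate along it. This is only a sketch of the idea, not the actual Nx implementation:

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical helper: express stack in terms of new_axis + concatenate.
stack_via_concat = fn tensors, axis ->
  tensors
  |> Enum.map(&Nx.new_axis(&1, axis))
  |> Nx.concatenate(axis: axis)
end

a = Nx.tensor([1, 2, 3])
b = Nx.tensor([4, 5, 6])

# Matches Nx.stack along both axis 0 and axis 1.
true =
  Nx.to_flat_list(stack_via_concat.([a, b], 0)) ==
    Nx.to_flat_list(Nx.stack([a, b], axis: 0))

true =
  Nx.to_flat_list(stack_via_concat.([a, b], 1)) ==
    Nx.to_flat_list(Nx.stack([a, b], axis: 1))
```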

billylanchantin commented 9 months ago

Agreed!

For posterity: with these additions to Nx, `stack(..., axis: 1)` no longer OOMs and takes 3.5 minutes on my VM.

My workaround was taking ~11.5 minutes, so this is a great improvement.

Further improvements should come from addressing this issue: