elixir-nx / scholar

Traditional machine learning on top of Nx
Apache License 2.0

Memory use for Affinity Propagation seems high, OOMing #171

Closed cigrainger closed 1 year ago

cigrainger commented 1 year ago

Hey! I'm running into an issue where I'm OOMing with Affinity Propagation on 5868 samples. I understand that memory use will be O(n^2) for the affinity matrix, but with f32 that should be 5868^2 * 4 = 137_733_696 bytes (~137 MB), no? That certainly shouldn't be blowing up the show. I snapped this before it crashed:

[Screenshot, 2023-09-04: memory usage shortly before the crash]

Here's a Livebook to replicate: https://gist.github.com/cigrainger/4f2531aa84cd8e1b3593d6c718a4d1ae

For reference, with sklearn:

[Screenshot: memory usage of the sklearn run, for comparison]

I assume there are plenty of optimisations in sklearn and comparisons may be a bit futile. Just not sure if there's a memory leak here or something. Thanks for the hard work on this lib and for any help!

Oh and if it matters: this is on an M1 Mac w/ 16 GB RAM.

msluszniak commented 1 year ago

Actually, the calculations are different at the moment: 5868^2 * 1024 * 4 ≈ 141 GB. I'll investigate whether we can change this somehow. https://github.com/elixir-nx/scholar/blob/2274bc50f854af2f26769455898ec1101bd2e835/lib/scholar/cluster/affinity_propagation.ex#L310-L313
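To make the gap between the two estimates concrete, here is the arithmetic in plain Elixir. The tensor shapes in the comments are my reading of what a broadcasted pairwise computation would materialize, not verified against the linked code:

```elixir
n = 5868
d = 1024
bytes_f32 = 4

# The {n, n} affinity matrix itself is modest:
affinity = n * n * bytes_f32
IO.puts("affinity matrix: #{affinity} bytes (~137 MB)")

# But a broadcasted pairwise-difference step of the form
#   Nx.subtract(Nx.new_axis(x, 1), Nx.new_axis(x, 0))
# materializes an {n, n, d} intermediate before reducing over d:
intermediate = n * n * d * bytes_f32
IO.puts("broadcast intermediate: #{intermediate} bytes (~141 GB)")

# The usual workaround computes squared distances without that tensor:
#   ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * (x_i . x_j)
# which only needs {n, 1}, {1, n}, and {n, n} pieces.
```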

msluszniak commented 1 year ago

Ok, I have a solution. I'll push changes today ;)

msluszniak commented 1 year ago

The fix proposed in #172 reduces much of the problem, but not all of it. Currently I'm able to compute it for Nx.iota({3000, 1024}), but for {5000, 1024} this part https://github.com/elixir-nx/scholar/blob/2274bc50f854af2f26769455898ec1101bd2e835/lib/scholar/cluster/affinity_propagation.ex#L134-L163 still causes memory usage to rise and crashes the program for large input tensors. @polvalente maybe you have some suggestions on what we can improve in this loop?
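One accounting note that may help narrow this down: the message-passing loop in affinity propagation only needs a handful of {n, n} f32 buffers per iteration, so its steady-state working set should be small. If memory keeps rising per iteration, the likely culprit is intermediates from each update not being released, rather than the working set itself. A rough sketch of the accounting (the buffer count is an assumption, not read from the code):

```elixir
n = 5868
bytes_f32 = 4
# assumed persistent buffers: similarity, responsibility, availability
buffers = 3

working_set = buffers * n * n * bytes_f32
IO.puts("steady-state loop working set: #{working_set} bytes (~413 MB)")
# Far below the ~141 GB peak from the distance computation, so a loop
# that crashes at {5000, 1024} suggests per-iteration garbage piling up.
```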

cigrainger commented 1 year ago

Amazing :). Thank you! This part is beyond my pay grade. And yeah... 141 GB! That's... a lot.

msluszniak commented 1 year ago

Okay, at this point I have no idea how we can reduce this algorithm's memory usage any further, so I'm merging the fix and closing for now ;)