rentruewang / koila

Prevent PyTorch's `CUDA error: out of memory` in just 1 line of code.
https://rentruewang.com/koila/
Apache License 2.0
1.82k stars 62 forks

Stack overflow (endless loop) when gradients are disabled #19

Closed. Tomsen1410 closed this 2 years ago

Tomsen1410 commented 2 years ago

I've just installed and tried out koila. However, there seems to be an endless loop when I apply it to my backbone model, which uses `Conv1d` and runs with gradients disabled. It also seems that koila does not handle the `permute` operation.

rentruewang commented 2 years ago

Thanks for the report!

TL;DR It'll be fixed when #18 is done.

The bug happens because I previously handled `LazyTensor` and `Tensor` differently using an if/else. However, that sometimes interferes with how PyTorch handles custom tensor classes (by calling a special function). I'm currently working on #18, which will solve this problem once and for all (and make the code a bit cleaner). However, I've been really busy with IRL stuff recently and haven't had time to work on it, so sadly it's going to take a while.

Tomsen1410 commented 2 years ago

Cool! Is there any way to quickly fix this myself? It seems like the function (`conv1d` in my case) calls `__torch_function__`, which calls `lazy_forward`, which directly calls `conv1d` again (because grads are disabled; why is it like that?), which calls `__torch_function__` again, and so on.
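
The loop described above can be reproduced without PyTorch. The sketch below is a pure-Python simulation, not koila's actual code: `Lazy`, `conv1d_like`, `naive_handler`, `unwrapping_handler`, and `run` are all hypothetical names, and the `handler` attribute only plays the role that `__torch_function__` plays for a custom tensor class. It assumes the eager path re-calls the public function with the wrapper still attached, which is what the call chain in this comment suggests.

```python
# Pure-Python simulation of the dispatch loop; no PyTorch required.
# Every name below is hypothetical and chosen for illustration only.

class Lazy:
    """Stand-in for a wrapper class such as LazyTensor."""

    def __init__(self, value, handler):
        self.value = value
        self.handler = handler  # plays the role of __torch_function__


def conv1d_like(x):
    """Stand-in for a library function that dispatches on wrapper arguments."""
    if isinstance(x, Lazy):
        # The library hands the call to the wrapper's hook.
        return x.handler(conv1d_like, x)
    return x * 2  # the "real" computation on a plain value


def naive_handler(func, x):
    # Eager path: re-calls the public function with the wrapper still
    # attached, so dispatch fires again and the call never bottoms out.
    return func(x)


def unwrapping_handler(func, x):
    # Strip the wrapper first, so the second call reaches the real computation.
    return func(x.value)


def run(handler):
    """Call conv1d_like on a wrapped value and report what happens."""
    try:
        return conv1d_like(Lazy(21, handler))
    except RecursionError:
        return "RecursionError"


print(run(naive_handler))       # RecursionError
print(run(unwrapping_handler))  # 42
```

In this toy version the only difference between the two handlers is whether the wrapper is removed before the function is called again. Whether the fix in koila takes exactly this form is not stated in this thread.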

rentruewang commented 2 years ago

It seems that this is the same issue as #3, which should have been fixed on the main branch. Is there a chance you're using the package installed from PyPI (pip)? The package there isn't up to date because of the rewrite.

If that's the case, maybe you could try `pip install git+https://github.com/rentruewang/koila@main`.

Edit: I've pushed to PyPI. Now `pip install --upgrade koila` should work. Feel free to reopen if it doesn't!
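
To confirm which build is actually installed after upgrading, something like the following might help. It is a minimal sketch that relies only on the standard library's `importlib.metadata` (Python 3.8+). The helper name `installed_version` is made up for illustration, and the thread does not say which version number contains the fix, so the output is only useful for comparing against the latest release listed on PyPI.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints the installed koila version, or None if koila is not installed
# in the current environment.
print(installed_version("koila"))
```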