huggingface / candle

Minimalist ML framework for Rust

How to reduce memory usage of backpropagation? #1241

Open laptou opened 10 months ago

laptou commented 10 months ago

I implemented the tiny NeRF example using candle here: https://github.com/laptou/nerfy/blob/fc50dbd61c4012d1f12f556a72474b59a8b3c158/examples/tiny_nerf.rs

The original example, which is written using TensorFlow, runs fine on my laptop. My candle implementation consumes all available memory: it crashes my desktop session if I use the CPU and errors out with a CUDA memory allocation error if I use the GPU. I'm running on a laptop with 32 GB of RAM, 32 GB of swap, and an RTX A3000 with 12 GB of VRAM.

I'm barely able to run it on CPU if I decrease the hidden layer size from 256 to 64.


I tracked the memory allocations using heaptrack, and it seems like most of them are related to keeping track of the operations for backpropagation.

Can you spot any obvious issues in my implementation that are causing it to consume so much memory? Is there a way that I can disable or reduce this behavior in some parts of the code to reduce the amount of memory that it uses?

laptou commented 10 months ago


Here's a screenshot from heaptrack that I am using to blame backpropagation for my memory woes.

The program errors out with Error: Cuda(Cuda(DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory"))) almost immediately if I try to use CUDA, so I think that there are additional issues at play here besides the backpropagation impl possibly being space-inefficient.

laptou commented 10 months ago

It seems like the reference count of a tensor (and therefore its Storage) never reaches zero until it can no longer be used for backpropagation at all. Seems related to this comment.
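To illustrate the ref-counting behavior (a toy sketch only, not candle's actual types): every recorded op holds an Arc clone of its inputs' storage, so the count only drops once the recorded graph itself is dropped.

```rust
use std::sync::Arc;

// Toy stand-ins, not candle's real types.
struct Storage(Vec<f32>);

struct Tensor {
    storage: Arc<Storage>,
    // Inputs recorded for backprop; the Arc clones keep their storage alive.
    op_inputs: Vec<Arc<Storage>>,
}

fn add(a: &Tensor, b: &Tensor) -> Tensor {
    let data: Vec<f32> = a.storage.0.iter().zip(&b.storage.0).map(|(x, y)| x + y).collect();
    Tensor {
        storage: Arc::new(Storage(data)),
        op_inputs: vec![a.storage.clone(), b.storage.clone()],
    }
}

fn main() {
    let a = Tensor { storage: Arc::new(Storage(vec![1.0; 4])), op_inputs: vec![] };
    let b = Tensor { storage: Arc::new(Storage(vec![2.0; 4])), op_inputs: vec![] };
    let c = add(&a, &b);
    drop(a);
    drop(b);
    // Even though `a` and `b` were dropped, their storage is still alive
    // because the recorded op inside `c` holds the last Arc to each of them.
    assert_eq!(Arc::strong_count(&c.op_inputs[0]), 1);
    println!("input storage kept alive by the recorded graph");
}
```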

laptou commented 10 months ago

I reimplemented Tiny NeRF using tch. Neither implementation works on CUDA with f64. The tch version works much better with f32, but the candle version does not improve much (it still doesn't work on CUDA). On CPU, memory usage is significantly lower for both the f32 and f64 tch implementations than for the corresponding candle implementations.

LaurentMazare commented 10 months ago

Interesting, the memory usage of candle during backprop is certainly not optimal. Could you maybe measure the memory footprint after the forward pass but before the backward pass? If these are on par between tch and candle, it would mean that the tracking during backprop is the issue - this one should be easy to improve. If they are already very different, I'm not really sure what would be going on - maybe tch somehow notices that some parts of the graph are not necessary for backprop, but I'm not sure which parts, so trimming down the example might be a good idea (it's certainly a bit long at the moment). Also, all of this should only be an issue if you care about backprop; if you only care about inference, it's easy to disable gradient tracking, which should reduce the memory footprint a lot.
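For the inference case, a minimal sketch of what disabling gradient tracking via detach can look like (assuming recent candle-core APIs; the tensors and shapes here are made up for illustration):

```rust
// A toy sketch, assuming recent candle-core APIs (in older versions `detach`
// returns a Result and needs `?`); tensors and shapes are made up.
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Var::from_tensor(&Tensor::randn(0f32, 1f32, (64, 64), &dev)?)?;
    let x = Tensor::randn(0f32, 1f32, (1, 64), &dev)?;

    // Training-style forward pass: because `w` is a Var, the op graph for
    // backprop is recorded and keeps intermediate tensors alive.
    let y_train = x.matmul(w.as_tensor())?;

    // Inference: detach the parameters first, so no graph is recorded during
    // the forward pass and nothing extra is retained.
    let w_inf = w.as_tensor().detach();
    let y_inf = x.matmul(&w_inf)?;

    println!("{y_train}\n{y_inf}");
    Ok(())
}
```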

LaurentMazare commented 10 months ago

Also, could you give it a try with #1243? This should only improve memory consumption on the backprop side, by avoiding keeping all the intermediary values alive until the end of the backward pass.

antodyn commented 10 months ago

When I try with #1243, I can successfully set the batch size back to 16. Thank you very much, that seems to have solved my problem. But I would like to ask: does it make a difference when training with or without grad.detach()?

LaurentMazare commented 10 months ago

Using #1243 shouldn't make a difference as long as you're not doing very hacky things such as taking the second-order derivative by applying grad to grad. If you just have a normal optimizer loop with SGD or AdamW, I would expect the behavior to be exactly identical.

antodyn commented 10 months ago

Thanks a lot. That really helped me.
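For reference, the "normal optimizer loop" mentioned above looks roughly like this minimal sketch (assuming candle-core/candle-nn's SGD and Optimizer APIs; the toy linear-regression data and names are made up):

```rust
use candle_core::{DType, Device, Result, Tensor, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy data: 100 samples with 3 features; the target is the feature sum.
    let x = Tensor::randn(0f32, 1f32, (100, 3), &dev)?;
    let y = x.sum(1)?;

    // The only trainable parameter; the optimizer tracks this Var.
    let w = Var::zeros((3, 1), DType::F32, &dev)?;
    let mut opt = SGD::new(vec![w.clone()], 0.01)?;

    for _step in 0..100 {
        let pred = x.matmul(w.as_tensor())?.squeeze(1)?;
        let loss = pred.sub(&y)?.sqr()?.mean_all()?;
        // Computes gradients w.r.t. the tracked Vars and applies the SGD
        // update; no explicit grad.detach() call is needed here.
        opt.backward_step(&loss)?;
    }
    println!("w = {}", w.as_tensor());
    Ok(())
}
```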

LaurentMazare commented 10 months ago

I've merged #1243 as it seems to fix things for you. @laptou it would be great to have your experience with this too! I wouldn't expect the new behavior to be any different, and the additional compute cost should be very small (tensors are detached lazily, i.e. nothing happens if they were already detached, which will often be the case). However, if you notice anything weird, you can get the old behavior back by setting the environment variable CANDLE_GRAD_DO_NOT_DETACH=1.

laptou commented 10 months ago

I'll test it out soon!

laptou commented 10 months ago

Just tested it, here are my observations:

torch cuda f32 8x256: mem usage = 6398 MiB, train time = 4m30s
torch cuda f64 8x256: crash due to gpu oom
torch cpu f32 8x256: mem usage = 6304 MiB, train time = 74m17s

before merge:
candle cuda f32 8x256: crash due to gpu oom during forward pass
candle cuda f32 8x64: crash due to gpu oom at backpropagation step

after merge:
candle cuda f32 8x256: crash due to gpu oom during forward pass (running the network forward uses way too much memory)
candle cuda f32 8x64: peak mem usage = 9426 MiB, train time = 4m6s
Running the network forward on a single batch makes GPU memory increase by 1966 MiB (1886 MiB -> 3852 MiB), even though the result is only a 65536x4 f32 array (65536 × 4 × 4 bytes = 1 MiB), probably because the intermediate tensors aren't freed. It OOMs on 8x256 because too many batches are held in memory at once and each batch consumes a ton of memory.

LaurentMazare commented 10 months ago

Thanks for trying this out, that's very interesting. Do you have a sense of why the PyTorch model wouldn't have to retain the intermediary values? It could be that some ops on the candle side get split into multiple sub-ops and we retain every one of these, whereas in PyTorch they result in a single op that gets retained. It could also be a difference in the model, but I imagine you've made them pretty similar - also, you should ensure that the set of variables in the VarMap only contains the variables that you want the optimizer to work on, so as not to trigger unrequired dependencies. Could you measure the PyTorch memory usage for 8x64? Also, maybe measure the forward pass only, if that's easy?
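A minimal sketch of the VarMap point above (candle-nn API; the module and the name "layer1" are placeholders): the optimizer should receive exactly the trainable parameters, nothing more.

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{linear, Optimizer, VarBuilder, VarMap, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // Every parameter created through `vb` is registered in `varmap`;
    // "layer1" is just a placeholder name for this sketch.
    let _layer = linear(8, 4, vb.pp("layer1"))?;

    // The optimizer gets exactly varmap.all_vars(); any extra Vars stashed in
    // the map (e.g. buffers that are not trained) add gradient-tracking
    // dependencies, and therefore memory, during backprop.
    let _opt = SGD::new(varmap.all_vars(), 0.01)?;
    Ok(())
}
```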

laptou commented 10 months ago

Sure.

I attached a debugger and placed a breakpoint on the line that calls Optimizer::backward_step, then measured the memory usage using nvidia-smi before and after calling it.

The implementations of the two models are as close to 1:1 as I could get them. Here's the diff.

LaurentMazare commented 10 months ago

Thanks, it's hard to tell just by looking at the code as the model is pretty involved. A way to dig into this a bit would be to use heaptrack, as you already did, while running only the forward pass on CPU, and see if there are any ops that could be optimized to reduce allocations. The most likely thing here is that some part of the architecture gets broken down into lots of ops on the candle side and we retain values for all these intermediary steps, whereas on the PyTorch side the same parts may be expressed with a far smaller number of ops.

iwanders commented 1 month ago

I ran into the same high memory usage trying to train my implementation of FCN; it uses roughly twice the memory that PyTorch does:

Before forward:  1250MiB/7930MiB
After  forward:  3138MiB/7930MiB
Before backward: 3102MiB/7930MiB
After backward:  7841MiB/7930MiB

VRAM usage was captured with this code; I then annotated the main backward method to see where the memory jumps:

vram_through_backpropagation.txt
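(This isn't the code linked above, which I haven't reproduced here; one possible way to capture the same numbers is to shell out to nvidia-smi, assuming it is on PATH:)

```rust
use std::process::Command;

// Returns the GPU's currently used memory in MiB, or None if nvidia-smi is
// unavailable or its output cannot be parsed.
fn vram_used_mib() -> Option<u64> {
    let out = Command::new("nvidia-smi")
        .args(["--query-gpu=memory.used", "--format=csv,noheader,nounits"])
        .output()
        .ok()?;
    let text = String::from_utf8(out.stdout).ok()?;
    text.lines().next()?.trim().parse().ok()
}

fn main() {
    println!("Before forward: {:?} MiB", vram_used_mib());
    // ... run the forward pass here ...
    println!("After forward:  {:?} MiB", vram_used_mib());
}
```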

It appears to creep up with each operation: the majority of the increase happens in the Op::Conv2D backward pass for my network. One instance, for example, shows:

[ 3449MiB/7930MiB ] Remove node: TensorId(12152) Tensor[dims 3, 4096, 9, 9; f32, cuda:0]
[ 3449MiB/7930MiB ] Detached grad:: TensorId(12373)
[ 3449MiB/7930MiB ] Start of Op::Conv2D
[ 4633MiB/7930MiB ] End of Op::Conv2D
[ 4633MiB/7930MiB ] After node iter: 14

This shows a large jump in backprop.rs when Op::Conv2D is handled, but memory also increases by 32 MiB in the addition just following the detached 12282 tensor.

After the backpropagation finishes, the total size of the gradients is 556408028 bytes, roughly 550 MB, which is an expected amount, but the overall VRAM grew by more than a factor of two, to over 7 GB.

I proceeded to write a singleton to track tensor creation and deletion, but unfortunately those backtraces do not point at a smoking gun, because tensors appear to be dropped while the strong count on their storage is still greater than one, likely because storage can be shared between tensors. The backtraces do point at the math operations used throughout the backpropagation calculation itself...

I did take a stab at sprinkling detach() throughout the backprop function in e357f2bd58d3690a7bf52adae38e577becac792b, which reduced memory usage by 2 GB:

Before forward:   1264MiB/7930MiB
After  forward:   3152MiB/7930MiB
Before backward:  3120MiB/7930MiB
After backward:   5816MiB/7930MiB

but then I thought there must be a better / more correct way to do this, perhaps with a thread-local bool or something similar that disables recording of backpropagation information while some RAII object is held (see the sketch below).
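Something like the following sketch of that RAII idea (not an existing candle API; the names are made up): a thread-local flag that op-recording code could consult, toggled by a guard.

```rust
use std::cell::Cell;

thread_local! {
    // Whether ops executed on this thread should record backprop information.
    static GRAD_ENABLED: Cell<bool> = Cell::new(true);
}

/// While an instance of this guard is alive, grad recording is disabled; the
/// previous state is restored when the guard is dropped.
pub struct NoGradGuard {
    prev: bool,
}

impl NoGradGuard {
    pub fn new() -> Self {
        let prev = GRAD_ENABLED.with(|f| f.replace(false));
        NoGradGuard { prev }
    }
}

impl Drop for NoGradGuard {
    fn drop(&mut self) {
        let prev = self.prev;
        GRAD_ENABLED.with(|f| f.set(prev));
    }
}

pub fn grad_enabled() -> bool {
    GRAD_ENABLED.with(|f| f.get())
}

fn main() {
    assert!(grad_enabled());
    {
        let _guard = NoGradGuard::new();
        // Op construction would check grad_enabled() and skip recording here.
        assert!(!grad_enabled());
    }
    assert!(grad_enabled());
    println!("grad recording restored");
}
```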

I've already spent quite a bit of time trying to understand this issue and attempting to solve it, and I'm not sure how much more time I want to commit, so I thought I'd share my findings so far.