ORNL / AADL

Anderson Acceleration for Deep Learning

Conserve gpu memory by storing history on cpu memory instead #3

Closed: henrymai closed this 3 years ago

henrymai commented 3 years ago

This patch offloads AADL history to the cpu memory instead of using valuable gpu memory.

This incurs a performance hit from transferring the vectors to and from cpu memory, but it allows training with a smaller reduction in batch size than would be needed without the patch, without running out of memory.

This change also fixes a bug with torch.nn.utils.convert_parameters.vector_to_parameters where it does not preserve the memory_format of param.data.
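The patch text does not show the fix itself. As we understand it, vector_to_parameters assigns a view of the flat vector to each param.data, which produces a contiguous tensor and drops a channels_last layout; copying into the existing storage instead keeps the layout. The NumPy sketch below illustrates only that idea: vector_to_arrays_inplace is a hypothetical helper, a transposed array stands in for torch.channels_last, and in PyTorch the comparable in-place operation would be Tensor.copy_.

```python
import numpy as np


def vector_to_arrays_inplace(vec, arrays):
    """Copy consecutive slices of the flat vector `vec` into `arrays`, in place.

    Hypothetical helper. Slice assignment (arr[...] = ...) writes through each
    array's existing strides, so a non-C-contiguous layout is kept.
    """
    offset = 0
    for arr in arrays:
        n = arr.size
        arr[...] = vec[offset:offset + n].reshape(arr.shape)
        offset += n
    if offset != vec.size:
        raise ValueError("vector length does not match the number of elements")


# Logical shape (N, C, H, W) but stored in N, H, W, C order, like channels_last.
weight = np.zeros((2, 4, 5, 3)).transpose(0, 3, 1, 2)
bias = np.zeros(3)
vec = np.arange(weight.size + bias.size, dtype=np.float64)

vector_to_arrays_inplace(vec, [weight, bias])
print(weight.flags["C_CONTIGUOUS"])                        # layout kept: False
print(np.array_equal(weight.ravel(), vec[:weight.size]))  # values in logical order: True
```

Rebinding instead (weight = vec[:weight.size].reshape(weight.shape)) would give a C-contiguous array, which is the NumPy counterpart of the lost layout.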

History device offload is configurable, so users who prefer to keep the history in gpu memory can continue to do so by calling accelerate(..., history_device="cuda").
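To see why the history can dominate gpu memory, here is a back-of-the-envelope estimate. It is a sketch under stated assumptions, not a measurement of AADL: it assumes the history is kept as one flattened fp32 copy of the parameters per iterate, and that DX and DR are materialized as full matrices during the mixing step. The function name and the depth value are hypothetical.

```python
def anderson_memory_gib(num_params, depth, bytes_per_elem=4):
    """Rough memory (GiB) for the history X and the DX/DR workspace.

    Hypothetical accounting for illustration, not measured from AADL:
      X  holds depth + 1 flattened iterates -> num_params x (depth + 1)
      DX holds their differences            -> num_params x depth
      DR holds differences of DX            -> num_params x (depth - 1)
    bytes_per_elem=4 corresponds to fp32.
    """
    elems_history = num_params * (depth + 1)
    elems_workspace = num_params * (depth + (depth - 1))
    gib = 2 ** 30
    return elems_history * bytes_per_elem / gib, elems_workspace * bytes_per_elem / gib


# ResNet-50 has roughly 25.6 million parameters; depth 10 is an arbitrary example.
history, workspace = anderson_memory_gib(25_600_000, depth=10)
print(f"history: {history:.2f} GiB, DX + DR: {workspace:.2f} GiB")
```

Under these assumptions this gives about 1.05 GiB of history and 1.81 GiB of workspace. With history_device="cpu" only the history term leaves the gpu; if the computation stays on the gpu, the DX and DR workspace still has to fit there while the mixing step runs.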

For reference, without cpu memory offload I get the following error after about 90 iterations:

RuntimeError: CUDA out of memory. Tried to allocate 2.55 GiB (GPU 0; 24.00 GiB total capacity; 16.60 GiB already allocated; 1.82 GiB free; 19.75 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

With cpu memory offload I'm able to go 300+ iterations (at the same batch size as the failure scenario above).

henrymai commented 3 years ago

Hi @allaffa,

Thanks for merging this in. Note that this pull request also included the change to line 19; you may want to revert that change specifically, since you mentioned here: https://github.com/ORNL/AADL/issues/2#issuecomment-954783891 that the original line was correct.

But as I mentioned in my reply in that issue, if you revert line 19 back to the original, the code will error out because R is undefined before it is used on line 19.

allaffa commented 3 years ago

Hi @henrymai

The matrix R is indeed a typo. So rather than reverting your PR, I would prefer to open a new one that overwrites the definition of the right-hand side of the least-squares problem.

DX is the matrix whose columns contain the residuals at consecutive iterates. DR is the matrix whose columns contain the differences between residuals computed at consecutive iterates.

So the right-hand side must be DX[:,-1].
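To make the roles of DX, DR and the right-hand side concrete, here is a small NumPy sketch of one Anderson mixing step. It is our reconstruction, not AADL's code: it assumes X stores one flattened iterate per column (oldest first) produced by a fixed-point map, takes DX and DR as column differences as described above, and uses the extrapolation formula that appears in anderson_normal_equation in the traceback later in this thread. The name anderson_extrapolate is hypothetical.

```python
import numpy as np


def anderson_extrapolate(X):
    """One Anderson mixing step from a history matrix X of shape (n, m + 1).

    Column j of X is the flattened iterate x_j (oldest first), produced by a
    fixed-point map x_{j+1} = g(x_j), so x_{j+1} - x_j is the residual of x_j.
    """
    DX = X[:, 1:] - X[:, :-1]        # residuals at consecutive iterates, (n, m)
    DR = DX[:, 1:] - DX[:, :-1]      # differences of consecutive residuals, (n, m - 1)
    rhs = DX[:, -1]                  # right-hand side: the most recent residual
    gamma, *_ = np.linalg.lstsq(DR, rhs, rcond=None)  # min ||DR @ gamma - rhs||
    # X[:, -2] + DX[:, -1] is g(x_k); DX[:, :-1] + DR are the differences of g.
    return X[:, -2] + DX[:, -1] - (DX[:, :-1] + DR) @ gamma


# Affine contraction g(x) = A x + b with fixed point (I - A)^{-1} b = [2, 5].
A = np.diag([0.5, 0.8])
b = np.array([1.0, 1.0])
iterates = [np.zeros(2)]
for _ in range(3):
    iterates.append(A @ iterates[-1] + b)
X = np.stack(iterates, axis=1)                 # shape (2, 4)
x_star = np.linalg.solve(np.eye(2) - A, b)

print(X[:, -1])                 # plain iteration after 3 steps
print(anderson_extrapolate(X))  # Anderson extrapolation from the same history
```

Because the map is affine and two-dimensional, two residual differences are enough for the least-squares step to land on the fixed point [2, 5] (up to rounding), while the last plain iterate is still at [1.75, 2.44].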

This typo was not in the GitLab repository; it showed up in the first ORNL/GitHub release after a few changes made to make the code clearer. Thank you for pointing this out.

henrymai commented 3 years ago

Ok, I'll submit a new pull request to change the right-hand side to DX[:,-1].

allaffa commented 3 years ago

@henrymai

Besides improvement in time, did you also notice any changes in the convergence of the training?

If the memory.contiguous format is NOT enforced as you proposed in the PR, the least-squares problem solved to compute the Anderson mixing coefficients should still be correct, because switching between contiguous and non-contiguous layouts only permutes the rows of the matrix DR and of the right-hand side DX[:,-1]. A permutation of the rows does not change the solution of the least-squares problem.
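The permutation argument can be checked numerically. The NumPy sketch below uses a transposed array as a stand-in for torch.channels_last: flattening it in memory order instead of logical order reorders the entries, and applying that same reordering to the rows of a least-squares problem leaves the solution unchanged. The random DR and right-hand side are made-up data.

```python
import numpy as np

# Stand-in for channels_last: logical shape (C, H, W), memory laid out as (H, W, C).
x = np.arange(2 * 3 * 4, dtype=np.float64).reshape(2, 3, 4)
x_cl = np.ascontiguousarray(x.transpose(1, 2, 0)).transpose(2, 0, 1)

logical = x_cl.ravel(order="C")   # flatten in logical index order
memory = x_cl.ravel(order="K")    # flatten in the order elements sit in memory
# Entries of x are 0..23, so `memory` doubles as the row permutation it induces.
perm = memory.astype(int)

rng = np.random.default_rng(0)
DR = rng.standard_normal((logical.size, 3))   # made-up least-squares matrix
rhs = rng.standard_normal(logical.size)       # made-up right-hand side

gamma, *_ = np.linalg.lstsq(DR, rhs, rcond=None)
gamma_perm, *_ = np.linalg.lstsq(DR[perm], rhs[perm], rcond=None)

print(np.array_equal(logical, memory))   # False: the two flattenings differ
print(np.allclose(gamma, gamma_perm))    # True: same least-squares solution
```

The invariance relies on two things: every history vector (hence every column of DR, and the right-hand side) must be flattened with the same ordering, and the extrapolated vector must be written back to the parameters using that same ordering.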

My colleague @jqyin is running the new version of the code on SUMMIT and cannot reproduce the same results obtained with the old version of the code.

henrymai commented 3 years ago

Besides improvement in time, did you also notice any changes in the convergence of the training?

This patch was not for an improvement in time. The main thing this patch does is to store the history of parameters on the cpu side to save on gpu memory. As I mentioned in the commit message, it actually makes the time worse as there is overhead for transferring the parameters to the cpu memory (to append to the history on the cpu memory) and also overhead in transferring it back to the gpu to perform the anderson_acceleration calculations.

In regards to noticing a convergence difference, I don't have a baseline since the original would eat up too much gpu memory for me to actually train for long enough (and also it would throw an exception when it hit the typo), but as I mention below, there shouldn't be a functional difference.

If the memory.contiguous format is NOT enforced as you proposed in the PR, the least-squares problem solved to compute the Anderson mixing coefficients should still be correct, because switching between contiguous and non-contiguous layouts only permutes the rows of the matrix DR and of the right-hand side DX[:,-1]. A permutation of the rows does not change the solution of the least-squares problem.

Right, contiguous or not, it shouldn't change the results; it's just that I get an error from pytorch for .view() when my model weights have memory_format=torch.channels_last.

My colleague @jqyin is running the new version of the code on SUMMIT and cannot reproduce the same results obtained with the old version of the code.

That is strange, it should be functionally equivalent to the original.

Could there be some other changes that might not have made it over from the gitlab repo to the github repo?

This is referring to your comment earlier:

This typo was not in the GitLab repository; it showed up in the first ORNL/GitHub release after a few changes made to make the code clearer. Thank you for pointing this out.

If possible, can your colleague try reverting my changes (and only fix the original typo) and see if they can reproduce the results using the original github code?

allaffa commented 3 years ago

This patch was not for an improvement in time. The main thing this patch does is to store the history of parameters on the cpu side to save on gpu memory. As I mentioned in the commit message, it actually makes the time worse as there is overhead for transferring the parameters to the cpu memory (to append to the history on the cpu memory) and also overhead in transferring it back to the gpu to perform the anderson_acceleration calculations.

Actually, my colleague @jqyin noticed that offloading to CPU memory and reloading to GPU every time, instead of keeping everything stored on GPU, reduced the training time for ResNet50 on ImageNet1k by a factor of 3.

In regards to noticing a convergence difference, I don't have a baseline since the original would eat up too much gpu memory for me to actually train for long enough (and also it would throw an exception when it hit the typo), but as I mention below, there shouldn't be a functional difference.

Can you give me some indication about what type of neural network architecture you are using, and what type of data you train the neural network on?

Right, contiguous or not, it shouldn't change the results; it's just that I get a warning from pytorch for .view() if I am using this with my model weights having memory_format=torch.channels_last.

My colleague @jqyin is running the new version of the code on SUMMIT and cannot reproduce the same results obtained with the old version of the code.

That is strange, it should be functionally equivalent to the original.

Could there be some other changes that might not have made it over from the gitlab repo to the github repo?

We are going to double-check the two versions of the code and keep you informed about any differences. If we find any, we will open a PR to introduce the missing pieces. Thanks.

henrymai commented 3 years ago

Actually, my colleague @jqyin noticed that offloading to CPU memory and reloading to GPU every time, instead of keeping everything stored on GPU, reduced the training time for ResNet50 on ImageNet1k by a factor of 3.

Wow, that is very interesting. Out of curiosity, I have a bunch of questions about Summit.

Is the latency on the NVLink bus between the POWER9 processors and the GPUs fairly low compared to PCIe?

Did @jqyin increase the batch size due to having more gpu memory usable, to increase the training speed?

Also wondering if Summit has true unified memory between the cpus and gpus where no movement actually happens or if it is just "fake" unified memory where the transfers are managed behind the scenes by cuda?

I also have a hunch that eventually, in addition to history_device=torch.cpu, switching compute_device to torch.cpu may increase the speed further on cpus that support AVX512 (by eliminating the cpu -> gpu transfer overhead, while the cpu can still perform the calculations quickly using AVX512).

Can you give me some indication about what type of neural network architecture you are using, and what type of data you train the neural network on?

My model consists mostly of convolution layers and my inputs are 3x128x128 images.

EDIT: Also to be clear, I mean that I don't have a baseline for "pre-patch" AADL vs "post-patch" AADL in regards to convergence performance between the two. However, I can tell you that post-patch AADL converges similarly to how my model converges without AADL applied to the same optimizer that I'm using (NosAdam).

We are going to double-check the two versions of the code and keep you informed about any differences. If we find any, we will open a PR to introduce the missing pieces. Thanks.

Great, will be very interested to hear the results.

allaffa commented 2 years ago

@henrymai

My collaborator @jqyin tested a case where he set both history_device and compute_device to cpu, and ResNet50 with ImageNet1k seems to work fine with correct accuracy.
The accuracy deteriorates when history_device is the CPU and compute_device is the GPU.

henrymai commented 2 years ago

Are you guys using fp16 weights on the GPU?

allaffa commented 2 years ago

No. We are using fp32 single-precision floating-point.

henrymai commented 2 years ago

Just as an additional sanity check, have you guys also tested setting both history_device and compute_device to GPU?

allaffa commented 2 years ago

@henrymai

Neither of the following two situations works:

The only situation that works is the one we were already using before, with both history_device and compute_device on the CPU.

henrymai commented 2 years ago

The only situation that works is the one we were already using before, with both history_device and compute_device on the CPU

Just to clarify, you mean "before" as in this comment, right? https://github.com/ORNL/AADL/pull/3#issuecomment-969015094

Otherwise, if you meant before as in prior to my patch, that is actually the same as history_device = GPU and compute_device = GPU (assuming that you guys were training your models using GPU before).

Can you guys test locally by reverting my patch (git revert 23d8c0fdd7018bedbdb13677aeec59e0a038ea93) to confirm that it converges using the github version of the code? This should be equivalent to history_device = gpu and compute_device = gpu, as I mentioned above, but we should confirm.

allaffa commented 2 years ago

When training ResNet50 on ImageNet1k, the older version of the code could run only on CPUs, because it would run out of memory on the GPU. The new version of the code no longer runs out of memory on the GPU. When the new code runs on the CPU, the results are consistent with the older version. When the new version of the code uses the GPUs, the results are not consistent with the ones obtained running on CPU. We cannot compare the new code with the old code on GPUs because the old version of the code runs out of memory on the GPU.

allaffa commented 2 years ago

@jqyin, can you also confirm what I said above in response to @henrymai?

henrymai commented 2 years ago

When training ResNet50 on ImageNet1k, the older version of the code could run only on CPUs, because it would run out of memory on the GPU.

Oh ok, that makes sense now.

When the new code runs on the CPU, the results are consistent with the older version.

Ok, great, this confirms that my patch is indeed equivalent to before (just with added flexibility).

When the new version of the code uses the GPUs, the results are not consistent with the ones obtained running on CPU.

I think this means that the linear algebra routines computed on the gpu are returning different results from what the cpu returns.

I think for now, the history_device and compute_device should be set to cpu (I can open a pull request for this, or you guys can do it). Then we should investigate the differences in the results of the linear algebra routines between the cpu and gpu (likely something that we will need pytorch to fix, assuming that you guys are on the latest version of pytorch already).

henrymai commented 2 years ago

Here's the pull request to default to cpu for compute_device for now: https://github.com/ORNL/AADL/pull/6/files

allaffa commented 2 years ago

I think that the PyTorch version that @jqyin is using on Summit is not the latest one. The machine is very complex, and the software stack is not very stable, so most users need to install their own Python environment. I will ask @jqyin to provide specific details in that respect. The only thing I can do from my end (and I already discussed it with @jqyin) is to print the entries of the torch.tensors on screen when I train a DL model (not ResNet50 but some other smaller model) on the DGX machine with a fixed random seed. I want to compare the entries of the tensors between history_device = CPU + compute_device = CPU and history_device = GPU + compute_device = GPU.
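A minimal sketch of that comparison, assuming each run dumps its tensors to a dict of NumPy arrays (in PyTorch, tensor.detach().cpu().numpy(), after seeding with torch.manual_seed). The helper compare_snapshots, the key names, the tolerances and the values in the usage example are all hypothetical.

```python
import numpy as np


def compare_snapshots(run_a, run_b, rtol=1e-5, atol=1e-6):
    """Return {name: max_abs_diff} for entries that differ between two runs.

    Hypothetical helper: run_a and run_b map names (e.g. "step1/gamma") to
    array-like values. Names missing from one run, or with mismatched
    shapes, are reported with an infinite difference.
    """
    report = {}
    for name in sorted(run_a.keys() | run_b.keys()):
        if name not in run_a or name not in run_b:
            report[name] = float("inf")
            continue
        a = np.asarray(run_a[name], dtype=np.float64)
        b = np.asarray(run_b[name], dtype=np.float64)
        if a.shape != b.shape:
            report[name] = float("inf")
        elif not np.allclose(a, b, rtol=rtol, atol=atol):
            report[name] = float(np.max(np.abs(a - b)))
    return report


# Made-up snapshots from a CPU+CPU run and a GPU+GPU run.
cpu_run = {"step1/gamma": [0.25, -0.10], "step1/params": [1.0, 2.0, 3.0]}
gpu_run = {"step1/gamma": [0.25, -0.35], "step1/params": [1.0, 2.0, 3.0000001]}
print(compare_snapshots(cpu_run, gpu_run))  # only step1/gamma is reported
```

Logging the mixing coefficients as well as the parameters makes it easier to tell whether a divergence starts in the least-squares solve or elsewhere.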

henrymai commented 2 years ago

I think that the PyTorch version that @jqyin is using on Summit is not the latest one. The machine is very complex, and the software stack is not very stable, so most users need to install their own Python environment.

Yeah, there have been a lot of fixes over time in pytorch to their linear algebra routines, so this might be the reason for the differences.

The only thing I can do from my end (and I already discussed it with @jqyin) is to print the entries of the torch.tensors on screen when I train a DL model (not ResNet50 but some other smaller model) on the DGX machine with a fixed random seed. I want to compare the entries of the tensors between choosing history_device = CPU + compute_device = CPU and history_device = GPU + compute_device = GPU

I can also try the same thing when I have time, but I will be doing it from the latest nvidia NGC pytorch container.

jqyin commented 2 years ago

For the record, I'm using PyTorch v1.9.0 on Summit for those experiments.

henrymai commented 2 years ago

Ok, that is actually fairly recent, I don't think 1.10 would have changed anything in regards to this.

henrymai commented 2 years ago

@allaffa

I tried to make a minimal self-contained example (mnist) to reproduce the differences you're seeing between AADL gpu and AADL cpu.

I am able to see a loss/accuracy difference between AADL gpu and AADL cpu, but both of them still converge to 80%+ test accuracy.

In other words, the computations are definitely producing different results on cpu vs gpu, but that does not prevent the gpu version from converging in my example.
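Differences of this kind can appear even when the math is identical: floating-point addition is not associative, and cpu and gpu kernels can reduce sums in different orders. A three-number fp32 illustration in NumPy; it shows the general mechanism and does not establish that this is the cause in AADL.

```python
import numpy as np

a, b, c = np.float32(1e8), np.float32(1.0), np.float32(-1e8)

left = (a + b) + c    # 1e8 + 1 rounds back to 1e8 in fp32 (neighbouring values are 8 apart)
right = (a + c) + b   # the two large terms cancel exactly first, then 1 is added
print(left, right)
```

Here left is 0.0 and right is 1.0. In a long training run, small discrepancies like this can feed back through the optimizer and grow.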

Git repo for the example: https://github.com/henrymai/aadl_testing

To rule out hardware differences causing problems, can you try running my small example above on your hardware to see if you see similar results as what I published in the git repo?

jqyin commented 2 years ago

@henrymai thanks for the test.

I ran your test on Summit and got similar results:

jsrun -n1 -a1 -g1 python main.py --AADL_gpu
Test set: Average loss: 0.3678, Accuracy: 8929/10000 (89%)

jsrun -n1 -a1 -g1 python main.py --AADL_cpu
Test set: Average loss: 0.3116, Accuracy: 9085/10000 (91%)

jsrun -n1 -a1 -g1 python main.py
Test set: Average loss: 0.3642, Accuracy: 8757/10000 (88%)

I guess for complicated problems such as ImageNet training and/or distributed training, the differences are magnified.

henrymai commented 2 years ago

Do you guys at ORNL happen to have a support contract/relationship with nvidia?

I think this problem is probably down to their cuBLAS/cuSolver kernels that pytorch is calling into.

They might also have suggestions on how to mitigate this problem.

allaffa commented 2 years ago

@henrymai

I wonder whether the problem is general to PyTorch or limited to torch.linalg.lstsq. If the issue extends to PyTorch as a whole, we are dealing with a large problem. If the issue is limited to torch.linalg.lstsq, then we can easily patch it by avoiding the call to torch.linalg.lstsq for the time being.

This is possible by setting acceleration_type=anderson_normal_equation, which computes the Anderson mixing by explicitly constructing the normal equations to solve the least-squares problem. This is somewhat risky (and strongly discouraged by the numerical linear algebra community) because it squares the condition number, but it circumvents the use of torch.linalg.lstsq.
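A small NumPy sketch of the trade-off. The actual anderson_normal_equation code is not shown in this thread; gamma_normal_equation below is a hypothetical stand-in, and treating the acc_reg argument seen in the traceback as a Tikhonov term is only an assumption. The sketch checks that cond(DR^T DR) equals cond(DR)^2 and that both routes agree in float64 on a moderately conditioned, made-up problem.

```python
import numpy as np


def gamma_normal_equation(DR, rhs, reg=0.0):
    """Mixing coefficients from (DR^T DR + reg * I) gamma = DR^T rhs.

    Hypothetical stand-in; `reg` is an assumed Tikhonov regularization term.
    """
    G = DR.T @ DR
    return np.linalg.solve(G + reg * np.eye(G.shape[0]), DR.T @ rhs)


rng = np.random.default_rng(0)
DR = rng.standard_normal((50, 4)) @ np.diag([1.0, 1e-1, 1e-2, 1e-3])  # badly scaled columns
rhs = rng.standard_normal(50)

kappa = np.linalg.cond(DR)
kappa_normal = np.linalg.cond(DR.T @ DR)
print(f"cond(DR) = {kappa:.3g}, cond(DR^T DR) = {kappa_normal:.3g}")

gamma_ls, *_ = np.linalg.lstsq(DR, rhs, rcond=None)   # orthogonal-factorization route
gamma_ne = gamma_normal_equation(DR, rhs)              # normal-equation route
```

In float64 the squared condition number is harmless here. In fp32, where machine epsilon is about 1.2e-7, a DR with condition number 1e4 already gives a normal matrix with condition number 1e8, beyond what fp32 can resolve.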

henrymai commented 2 years ago

@allaffa

I actually get an out of memory error almost immediately when trying anderson_normal_equation:

@b2acb1bf91c8:aadl_testing$ python3 main.py --AADL_gpu
Device: cuda
AADL Device: cuda
Train Epoch: 1 [0/60000 (0%)]   Loss: 10.658417
Train Epoch: 1 [640/60000 (1%)] Loss: 25.446875
Train Epoch: 1 [1280/60000 (2%)]        Loss: 7.964392
Traceback (most recent call last):
  File "main.py", line 186, in <module>
    main()
  File "main.py", line 178, in main
    train(args, model, device, train_loader, optimizer, epoch, writer)
  File "main.py", line 57, in train
    optimizer.step()
  File "aadl_testing/AADL/accelerate.py", line 108, in averaged_accelerated_step
    acc_param = anderson.anderson_normal_equation(X, self.acc_relaxation, self.acc_reg)
  File "aadl_testing/AADL/anderson_acceleration.py", line 70, in anderson_normal_equation
    extr = X[:,-2] + DX[:,-1] - (DX[:,:-1]+DR)@gamma
RuntimeError: CUDA out of memory. Tried to allocate 2099.56 GiB (GPU 0; 24.00 GiB total capacity; 77.97 MiB already allocated; 21.39 GiB free; 92.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Same even when using --AADL_cpu.

allaffa commented 2 years ago

This line of code is exactly the same as the one used to update the solution with LSTQ. I don't understand why the out-of-memory error is happening on this line.

henrymai commented 2 years ago

@allaffa I figured it out, here's the pull request to fix it: https://github.com/ORNL/AADL/pull/7/files

Although, I'm not sure whether the math remains correct.

henrymai commented 2 years ago

@allaffa Here are the results for running anderson_normal_equation with my test:

python3 main.py --AADL_gpu
Test set: Average loss: 0.3993, Accuracy: 8529/10000 (85%)

python3 main.py --AADL_cpu
Test set: Average loss: 0.4911, Accuracy: 8361/10000 (84%)

Screenshot of tensorboard: https://i.imgur.com/RK2AMBp.png

allaffa commented 2 years ago

@henrymai Does the accuracy match that of LSTQ, or was LSTQ leading to better accuracy?

henrymai commented 2 years ago

@allaffa Better for gpu but worse for cpu. The results with LSTQ are here: https://github.com/henrymai/aadl_testing

I'll paste them here too for easy reference:

python3 main.py --AADL_gpu
Test set: Average loss: 0.4970, Accuracy: 8135/10000 (81%)

python3 main.py --AADL_cpu
Test set: Average loss: 0.3923, Accuracy: 8682/10000 (87%)

allaffa commented 2 years ago

I'm asking @jqyin to double-check how the normal equations behave with ResNet50 trained on ImageNet1k. The results seem to suggest:

jqyin commented 2 years ago

Also, according to the docs, torch.linalg.lstsq uses different least-squares implementations on CPU and GPU: QR with pivoting on CPU vs. plain QR on GPU. Maybe that causes the differences? Unfortunately, there's only one implementation on GPU, so we can't easily switch.
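One way this difference can matter: the torch.linalg.lstsq documentation lists 'gelsy' (QR with column pivoting, which is rank-revealing) as the default CPU driver and 'gels' (plain QR, which assumes a full-rank matrix) as the only CUDA driver. When history vectors become nearly collinear, DR is close to rank-deficient, and the two approaches can return very different mixing coefficients with similar residuals. The NumPy sketch below uses made-up data and only mimics the two behaviours (np.linalg.lstsq uses an SVD-based driver, not gelsy, and the cutoff is our choice, modelled on fp32 machine epsilon times max(m, n)); it does not establish that this is what happens on Summit.

```python
import numpy as np

eps = 1e-10
a = np.array([1.0, 1.0, 1.0, 1.0])
e = np.array([1.0, -1.0, 0.0, 0.0])           # orthogonal to a
DR = np.column_stack([a, a + eps * e])         # two nearly collinear columns
rhs = np.array([1.0, 2.0, 3.0, 4.0])

# Plain Householder QR, treating DR as full rank (in the spirit of 'gels').
Q, R = np.linalg.qr(DR)
gamma_qr = np.linalg.solve(R, Q.T @ rhs)

# Rank-revealing solve: singular values below rcond * s_max are treated as zero.
rcond = np.finfo(np.float32).eps * max(DR.shape)
gamma_rr, _, rank, _ = np.linalg.lstsq(DR, rhs, rcond=rcond)

for name, g in [("plain QR", gamma_qr), ("rank-revealing", gamma_rr)]:
    resid = np.linalg.norm(DR @ g - rhs)
    print(f"{name:15s} gamma = {g}, residual = {resid:.3f}")
```

The plain QR solve returns coefficients on the order of ±5e9, while the rank-revealing solve reports rank 1 and returns about [1.25, 1.25]; the residual norms are close (about 2.12 vs 2.24). Large, cancelling coefficients like these make the extrapolation step sensitive to rounding, especially in fp32.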

henrymai commented 2 years ago

@jqyin As I mentioned in an earlier comment, the gpu implementation calls into cuBLAS/cuSolver; see this: https://github.com/pytorch/pytorch/pull/54725/files

Also see this tracking issue about migrating more operations over to cuBLAS/cuSolver: https://github.com/pytorch/pytorch/issues/47953

Also as I mentioned earlier, I would suggest reaching out to nvidia developer support (if you guys have a contract/relationship with them) to see if they can investigate and provide solutions/mitigations.

nvidia actively contributes to pytorch, so they will probably be very helpful if you guys can reach out to them.