ltatzel / PyTorchHessianFree

PyTorch implementation of the Hessian-free optimizer
BSD 3-Clause "New" or "Revised" License

Avoid storing forward lists #1

Closed ltatzel closed 2 years ago

ltatzel commented 2 years ago

Issue: In the original version of acc_step, the mini-batch losses and outputs were computed in the _forward_lists method, stored in lists, and then used in a subsequent step to accumulate the total loss, gradient, or matrix-vector product. I chose this approach because it allowed the lists to be reused multiple times. The problem is that the lists can become quite large and occupy a lot of GPU memory. This PR addresses this issue.

Solution: Instead of storing the computed losses/outputs in lists, we now compute the loss/outputs for one mini-batch at a time and compute the requested quantity (mini-batch loss, gradient, or matrix-vector product) from them right away. This was achieved by merging the _forward_lists and _acc methods. The losses/outputs no longer need to be stored, so less memory is used. The downside of this approach is that the computed mini-batch loss/outputs cannot be reused. This results in redundant work: e.g., if all quantities are computed on the same data, the same forward pass is executed multiple times.
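To make the trade-off concrete, here is a minimal sketch of the two accumulation patterns. It is not the code from this repository: a toy scalar "model" in plain Python stands in for a PyTorch network, and all names (`forward`, `acc_with_lists`, `acc_streaming`, `calls`) are hypothetical and chosen for illustration only.

```python
from typing import Callable, List, Sequence, Tuple

# A mini-batch is a pair (inputs, targets). Hypothetical toy data format.
Batch = Tuple[List[float], List[float]]

# Counts forward passes so the redundant work becomes visible.
calls = {"forward": 0}


def forward(w: float, batch: Batch) -> List[float]:
    """Toy model: multiply each input by a scalar weight w."""
    calls["forward"] += 1
    inputs, _ = batch
    return [w * x for x in inputs]


def batch_loss(outputs: List[float], batch: Batch) -> float:
    """Sum of squared errors on one mini-batch."""
    _, targets = batch
    return sum((o - t) ** 2 for o, t in zip(outputs, targets))


def batch_grad(outputs: List[float], batch: Batch) -> float:
    """Derivative of the mini-batch loss with respect to w."""
    inputs, targets = batch
    return sum(2.0 * (o - t) * x for o, t, x in zip(outputs, targets, inputs))


def acc_with_lists(w: float, data: Sequence[Batch]) -> Tuple[float, float]:
    """Old pattern: store all outputs first, then reuse them for each quantity."""
    outputs_list = [forward(w, b) for b in data]  # grows with the number of batches
    loss = sum(batch_loss(o, b) for o, b in zip(outputs_list, data))
    grad = sum(batch_grad(o, b) for o, b in zip(outputs_list, data))
    return loss, grad


def acc_streaming(
    w: float,
    data: Sequence[Batch],
    quantity: Callable[[List[float], Batch], float],
) -> float:
    """New pattern: handle one mini-batch at a time and keep only the running total."""
    total = 0.0
    for b in data:
        outputs = forward(w, b)  # not kept after this iteration
        total += quantity(outputs, b)
    return total


data = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]

loss_a, grad_a = acc_with_lists(1.0, data)        # one forward pass per batch
loss_b = acc_streaming(1.0, data, batch_loss)     # one forward pass per batch ...
grad_b = acc_streaming(1.0, data, batch_grad)     # ... and again for the gradient
print(loss_a == loss_b, grad_a == grad_b, calls["forward"])
```

Both patterns yield the same totals. The list-based variant runs each forward pass once but holds every mini-batch output at the same time, while the streaming variant holds only one mini-batch's outputs but repeats the forward passes for every requested quantity.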

This "solution" is not ideal and can certainly be further optimized...