
PyTorchHessianFree

Here, we provide a PyTorch implementation of the Hessian-free optimizer as described in [1] and [2] (see below). This project is still under active development, so changes may be made at any time.

Core idea: At each step, the optimizer computes a local quadratic approximation of the target function and uses the conjugate gradient (cg) method to approximate its minimum (the Newton step). This method only requires matrix-vector products with the curvature matrix, which can be computed without ever forming this matrix explicitly in memory. This makes the Hessian-free optimizer applicable to large problems with high-dimensional parameter spaces (e.g. training neural networks).
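
To illustrate this matrix-free principle, here is a generic, self-contained sketch using plain torch.autograd (it is not the optimizer's actual implementation, and the helper names hessian_vector_product and cg are introduced only for this illustration): Hessian-vector products are obtained via double backpropagation and fed to a basic cg loop.

import torch


def hessian_vector_product(loss, params, vec):
    """Multiply the Hessian of `loss` w.r.t. `params` with the flat vector `vec`."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    hvs = torch.autograd.grad(flat_grad @ vec, params, retain_graph=True)
    return torch.cat([hv.reshape(-1) for hv in hvs])


def cg(mvp, b, max_iter=50, tol=1e-6):
    """Approximately solve `A x = b`, where `A` is accessed only via `mvp(v) = A v`."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual `b - A x` for `x = 0`
    p = r.clone()  # search direction
    rs = r @ r
    for _ in range(max_iter):
        Ap = mvp(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Sanity check on the quadratic 0.5 * w^T A w - b^T w: starting from w = 0, the
# Newton step is the solution of `A x = b`.
A = torch.randn(5, 5)
A = A @ A.T + 5 * torch.eye(5)  # symmetric positive definite
b = torch.randn(5)
w = torch.zeros(5, requires_grad=True)
loss = 0.5 * w @ (A @ w) - b @ w
grad = torch.autograd.grad(loss, [w], create_graph=True)[0]
newton_step = cg(lambda v: hessian_vector_product(loss, [w], v), -grad.detach())
print(torch.allclose(newton_step, torch.linalg.solve(A, b), atol=1e-4))  # True

The actual optimizer works the same way in principle, but never materializes the curvature matrix and uses dedicated routines (see below) for the curvature matrix-vector products.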

Credits: The pytorch-hessianfree repo by GitHub user fmeirinhos served as a starting point. For the matrix-vector products with the Hessian or GGN, we use functionality from the BackPACK package [3].

Table of contents:

  1. Installation instructions
  2. Example
  3. Structure of this repo
  4. Implementation details
  5. Contributing
  6. References

1. Installation instructions

If you want to use the optimizer, clone the repo from GitHub via git clone https://github.com/ltatzel/PyTorchHessianFree.git, navigate to the project folder with cd PyTorchHessianFree and install it with pip install -e .. Alternatively, install the package directly from GitHub via pip install hessianfree@git+https://git@github.com/ltatzel/PyTorchHessianFree.git@main.

Additional requirements for the tests, examples and dev can be installed via pip install -e ".[tests]", pip install -e ".[examples]" or pip install -e ".[dev]", respectively. For running the tests, execute pytest from the repo's root directory.

2. Example

"""A minimal working example using the `HessianFree` optimizer on a small neural
network and some dummy data.
"""

import torch

from hessianfree.optimizer import HessianFree

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
DIM = 10

if __name__ == "__main__":

    # Set up model, loss-function and optimizer
    model = torch.nn.Sequential(
        torch.nn.Linear(DIM, DIM, bias=False),
        torch.nn.ReLU(),
        torch.nn.Linear(DIM, DIM),
    ).to(DEVICE)
    loss_function = torch.nn.MSELoss()
    opt = HessianFree(model.parameters(), verbose=True)

    # Training
    for step_idx in range(5):

        # Sample dummy data, define the `forward`-function
        inputs = torch.rand(BATCH_SIZE, DIM).to(DEVICE)
        targets = torch.rand(BATCH_SIZE, DIM).to(DEVICE)

        def forward():
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
            return loss, outputs

        # Update the model's parameters
        opt.step(forward=forward)

3. Structure of this repo

The repo contains three folders:

  1. hessianfree: the Python package itself, i.e. the HessianFree optimizer and the functionality it builds on (e.g. the cg method and the curvature matrix-vector products).
  2. examples: example scripts demonstrating how to use the optimizer.
  3. tests: tests for the optimizer and its components.

4. Implementation details

5. Contributing

I would be very grateful for any feedback! If you have questions or a feature request, find a bug, or have comments on how to improve the code, please don't hesitate to reach out to me.

6. References

[1] "Deep learning via Hessian-free optimization" by James Martens. In Proceedings of the 27th International Conference on International Conference on Machine Learning (ICML), 2010. Paper available at https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf (accessed June 2022).

[2] "Training Deep and Recurrent Networks with Hessian-Free Optimization" by James Martens and Ilya Sutskever. Report available at https://www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf (accessed June 2022).

[3] "BackPACK: Packing more into Backprop" by Felix Dangel, Frederik Kunstner and Philipp Hennig. In International Conference on Learning Representations,

  1. Paper available at https://openreview.net/forum?id=BJlrF24twB (accessed June 2022). Python package available at https://github.com/f-dangel/backpack.