hjmshi / PyTorch-LBFGS

A PyTorch implementation of L-BFGS.
MIT License

Trying to accelerate LBFGS #13

Open muammar opened 5 years ago

muammar commented 5 years ago

Thanks for this implementation. Recently, I've been working on a package that uses machine learning for chemistry problems, where I use PyTorch to train some models. I have been able to perform distributed training with a library called dask, which accelerated the training phase.

When I use first-order optimization algorithms such as Adam, I get up to 3 optimization steps per second, but those algorithms converge slowly compared to second-order ones. When using LBFGS, I get only one optimization step every 7 seconds for the same number of parameters. I am interested in using a dask client to make some parts of your LBFGS implementation work in a distributed manner so that each optimization step is faster. I started reading the code and have only a rough idea of how the LBFGS algorithm works. Could you give me some hints about which parts of the module could be computed independently and therefore distributed?

I would appreciate your thoughts on this.

hjmshi commented 5 years ago

Thanks for your question! I'm not too familiar with dask and am not sure about your problem setting. Can you clarify what problem you are looking at? Is it a finite-sum problem?
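For reference, a finite-sum problem is one whose objective is an average of per-sample losses $f_i$. Because the gradient is then also a sum, it splits across disjoint data shards $S_1, \dots, S_K$ held by $K$ workers (the shard notation is introduced here only for illustration):

```math
F(w) = \frac{1}{n} \sum_{i=1}^{n} f_i(w),
\qquad
\nabla F(w) = \frac{1}{n} \sum_{k=1}^{K} \sum_{i \in S_k} \nabla f_i(w)
```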

Typically, SGD/Adam are distributed in a data-parallel fashion: each node evaluates the function/gradient over only a subset of the dataset, and the results are then aggregated. Something similar can be done for L-BFGS, although there are various possible approaches for dealing with the two-loop recursion, line search, etc. However, this approach makes sense only if function/gradient evaluations are the primary computational bottleneck (as they are in deep learning). If you have some additional details about your problem, I may be able to give some better ideas for distributing the algorithm.
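A minimal sketch of the data-parallel idea, assuming a finite-sum least-squares objective. The workers are simulated sequentially in plain Python, with no dask or PyTorch calls. The shard layout and all function names (`shard_loss_and_grad`, `aggregate`, `two_loop`, `lbfgs`) are hypothetical, and this is not this repository's implementation. Each shard returns a summed loss and gradient, those sums are averaged, and a textbook two-loop recursion plus a simple backtracking (Armijo) line search, chosen here for brevity, run only on the aggregated values:

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]
Sample = Tuple[Vector, float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def axpy(alpha: float, x: Sequence[float], y: Sequence[float]) -> Vector:
    """Return alpha * x + y as a new list."""
    return [alpha * xi + yi for xi, yi in zip(x, y)]


def shard_loss_and_grad(w: Vector, shard: List[Sample]) -> Tuple[float, Vector]:
    """Summed loss 0.5 * (x.w - y)^2 and its gradient over one shard (one worker's job)."""
    loss, grad = 0.0, [0.0] * len(w)
    for x, y in shard:
        r = dot(x, w) - y
        loss += 0.5 * r * r
        grad = axpy(r, x, grad)
    return loss, grad


def aggregate(w: Vector, shards: List[List[Sample]], n_total: int) -> Tuple[float, Vector]:
    """Average per-shard sums. In a real setup each shard would run on its own
    worker; this hypothetical sketch evaluates them one after another."""
    parts = [shard_loss_and_grad(w, s) for s in shards]
    loss = sum(l for l, _ in parts) / n_total
    grad = [sum(g[i] for _, g in parts) / n_total for i in range(len(w))]
    return loss, grad


def two_loop(grad: Vector, s_hist: List[Vector], y_hist: List[Vector]) -> Vector:
    """Two-loop recursion: returns the search direction -H * grad."""
    q, alphas = list(grad), []
    for s, y in zip(reversed(s_hist), reversed(y_hist)):  # newest pair first
        rho = 1.0 / dot(y, s)
        a = rho * dot(s, q)
        alphas.append((a, rho))
        q = axpy(-a, y, q)
    gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
    r = [gamma * qi for qi in q]
    for s, y, (a, rho) in zip(s_hist, y_hist, reversed(alphas)):  # oldest pair first
        b = rho * dot(y, r)
        r = axpy(a - b, s, r)
    return [-ri for ri in r]


def lbfgs(shards, dim, n_total, history=5, iters=50):
    w = [0.0] * dim
    loss, grad = aggregate(w, shards, n_total)
    s_hist, y_hist = [], []
    for _ in range(iters):
        if dot(grad, grad) < 1e-20:
            break
        d = two_loop(grad, s_hist, y_hist)
        if dot(grad, d) >= 0:  # safeguard: fall back to steepest descent
            d = [-g for g in grad]
        gd = dot(grad, d)
        t = 1.0
        # Every trial step is a full aggregated evaluation over all shards.
        while True:
            w_new = axpy(t, d, w)
            new_loss, new_grad = aggregate(w_new, shards, n_total)
            if new_loss <= loss + 1e-4 * t * gd or t < 1e-10:
                break
            t *= 0.5
        s = [a - b for a, b in zip(w_new, w)]
        y = [a - b for a, b in zip(new_grad, grad)]
        if dot(s, y) > 1e-12:  # keep only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > history:
                s_hist.pop(0)
                y_hist.pop(0)
        w, loss, grad = w_new, new_loss, new_grad
    return w, loss


# Usage: noiseless synthetic regression split into 4 shards of 10 samples.
random.seed(0)
w_true = [1.0, -2.0, 0.5]
data = []
for _ in range(40):
    x = [random.gauss(0.0, 1.0) for _ in w_true]
    data.append((x, dot(x, w_true)))
shards = [data[i:i + 10] for i in range(0, 40, 10)]
w_est, final_loss = lbfgs(shards, dim=3, n_total=len(data))
print(w_est, final_loss)
```

In this sketch every backtracking trial calls `aggregate` again, so each line-search step costs a full pass over all shards. By contrast, the two-loop recursion only touches a few length-`dim` vectors. This illustrates why distributing the function/gradient evaluations pays off mainly when they dominate the runtime.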