skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Support for optimizers that require a closure #207

Closed · tpietruszka closed this 6 years ago

tpietruszka commented 6 years ago

Right now, the toy examples break if the torch.optim.LBFGS optimizer is used (likely also some others, e.g. conjugate gradients).

As per https://pytorch.org/docs/stable/optim.html#optimizer-step-closure, the closure-based way of calling the optimizer seems to be preferred (and is certainly required by those optimizers).
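
For reference, the calling convention described in the linked docs looks roughly like this (model, loss_fn, dataset and optimizer are placeholders):

    for input, target in dataset:
        def closure():
            # clear old gradients, recompute the loss and its gradients,
            # and return the loss so the optimizer can re-evaluate it
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            return loss
        optimizer.step(closure)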

Right now, for my purposes, I've subclassed skorch.NeuralNetRegressor and overridden the train_step method with the following code. It seems to work; I'm just not sure whether it interferes with some other skorch machinery.

    def train_step(self, Xi, yi, **fit_params):
        y_pred = None
        loss = None

        def train_closure():
            # The optimizer may call this closure several times; the
            # nonlocal variables expose the latest results to the caller.
            nonlocal y_pred
            nonlocal loss
            self.module_.train()
            self.optimizer_.zero_grad()
            y_pred = self.infer(Xi, **fit_params)
            loss = self.get_loss(y_pred, yi, X=Xi, training=True)
            loss.backward()
            self.notify(
                'on_grad_computed',
                named_parameters=list(self.module_.named_parameters()),
            )
            return loss

        self.optimizer_.step(train_closure)
        return {
            'loss': loss,
            'y_pred': y_pred,
        }

Are there any downsides to this solution? Should I prepare a pull request?

BTW, it could be useful to be able to disable batching, since e.g. LBFGS does not use mini-batches. There is the obvious workaround of setting the batch size to len(training set), but being able to pass None as batch_size would be clearer, I think.
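
For context, the batch-size workaround looks something like this (just a sketch: ClosureRegressor stands for a NeuralNetRegressor subclass with the train_step override from above, and MyModule, X_train and y_train are placeholders):

    import torch

    net = ClosureRegressor(
        MyModule,                      # placeholder torch.nn.Module
        optimizer=torch.optim.LBFGS,
        batch_size=len(X_train),       # one "batch" covering the whole training set
        max_epochs=10,
    )
    net.fit(X_train, y_train)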

ottonemo commented 6 years ago

Thanks for reporting! It is a bit unfortunate that this will make the train step even more complicated than it already is, but I don't see a way around it.

There is one problem I can think of with using nonlocal variables to communicate the results to the outside: there is no guarantee that the optimizer uses the results of the last closure call; it might use intermediate results for the optimization step. This would introduce a mismatch between the loss/prediction that skorch returns and the loss/prediction that the optimizer actually used.

For the loss, the solution would be simple (loss = self.optimizer_.step(closure)), but this doesn't work for the prediction. Maybe we just have to live with that and hope that the optimizer doesn't do anything unexpected.

It'd be great if you could create a pull request. Then we can discuss the details there.

tpietruszka commented 6 years ago

Good point!

Here is an idea: maybe use inspect to check whether the optimizer's step() method requires a closure. If it does (i.e. the closure argument has no default), re-compute y_pred and loss after the call.
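
A minimal sketch of that check (step_requires_closure is just an illustrative name):

    import inspect

    def step_requires_closure(optimizer):
        # True if step() takes a `closure` parameter with no default value,
        # as e.g. torch.optim.LBFGS does; False for e.g. SGD (closure=None).
        param = inspect.signature(optimizer.step).parameters.get('closure')
        return param is not None and param.default is inspect.Parameter.empty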

I think it is a correct solution and only incurs a performance cost when necessary. I don't particularly like its style, but the mismatches you've mentioned might be worse.

Also, should I prepare a pull request with the code and some tests? Or do you prefer to only use the core team's code? (asking since you've tagged this In progress)

(BTW, I wish I had discovered this library sooner, as I wasted some time writing my own, less complete version. Really appreciate what you've done here!)

Related notes below (optional reading)

After skimming through https://pytorch.org/docs/master/_modules/torch/optim/lbfgs.html#LBFGS.step I have a couple of conclusions:

  1. If the optimizer terminates on any condition other than max_iter, there is a closure() call after the last state update; that would be the unexpected behavior you mentioned.

Because of this I had an idea: if we compare the loss passed through nonlocal with the one returned by optimizer_.step, we should know whether or not the last closure call's results were used, so we could re-compute y_pred only when necessary. (I would consider identical loss values for different y_pred unlikely...) It's a little messy, but it should work, except:

  2. The loss returned by LBFGS.step() is the result of the first closure() call, not the last one. I find this strange; either I don't understand it, or it is some sort of bug. Of course, it invalidates the approach described above.

benjamin-work commented 6 years ago

Was solved in #247