FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Rethink train design and better callbacks support #1461

Open DhairyaLGandhi opened 3 years ago

DhairyaLGandhi commented 3 years ago

There are a few cases where I find myself wondering whether we should make it more explicit how to extend the train loop design, so that callbacks don't have to cheat to get things like the loss. Further, packages like FluxTraining.jl show that we lack a set of preexisting callbacks that users shouldn't need to rewrite.

So keeping this in mind, I think using pullback instead of gradient would be a step towards that, as well as not optimising until a prehook has checked the callback conditions. This should also fit in nicely with how we want to set up schedulers. I would also want to figure out where distributed and multi-GPU training fall in this, so we know how to proceed.
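A minimal sketch of the pullback idea, assuming Zygote.jl is installed. The toy loss, data, and learning rate are made up for illustration; only `Zygote.pullback` is a real API here. The point is that the loss value is available before any gradient work happens, so a prehook could inspect it without recomputing the forward pass.

```julia
using Zygote  # assumption: Zygote.jl is available (it is Flux's AD backend)

# Hypothetical toy problem: fit weights w so that w .* x ≈ t.
x = [1.0, 2.0, 3.0]
t = [2.0, 4.0, 6.0]
w = [0.0, 0.0, 0.0]

loss(w) = sum(abs2, w .* x .- t)

# `pullback` runs the forward pass once and returns the loss value
# together with a closure that computes the gradient on demand.
l, back = Zygote.pullback(loss, w)

# A prehook could look at `l` here and decide to skip the step
# before the backward pass or the optimiser update is run.
gs = back(one(l))[1]        # gradient of the loss with respect to w

# Plain gradient-descent step, standing in for an optimiser update.
w_new = w .- 0.01 .* gs
```

With `gradient` alone, the loss would have to be recomputed or smuggled out through a closure, which is the "cheating" the comment refers to.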

We don't necessarily want to return the losses etc, but perhaps a slightly more trained model? This would fall in line with how Optimisers.jl is looking as well.
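For reference, a small sketch of the "return a slightly more trained model" style as it appears in current Optimisers.jl releases (an assumption: the API at the time of this thread may have differed). The model and gradient here are hand-written named tuples, purely for illustration.

```julia
using Optimisers  # assumption: a recent Optimisers.jl release

# Hypothetical model and gradient with matching structure.
model = (w = [0.0, 0.0],)
grad  = (w = [1.0, -2.0],)

# Optimiser state is built from the rule and the model's structure.
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# `update` is non-mutating: it returns the new state and a new model.
state, new_model = Optimisers.update(state, model, grad)
```

Here the training step is a pure function from (state, model, gradient) to (state, model), so nothing needs to be returned through callbacks.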

xref #1017 #1067 #1434

cc @lorenzoh

CarloLucibello commented 3 years ago

Here I totally second @ToucheSir's opinion expressed on Discourse:

I’m of the somewhat extreme opinion that train! should be removed outright and Flux should only expose gradient/pullback. The one-function API looks nice on the surface, but it is horribly inflexible and obscures a lot of common errors that crop up in training loops. PyTorch seems to do just fine without one too, so it’s not like removing it will cause an unmitigated UX disaster.

We should deprecate train! entirely and let higher-level packages such as FluxTraining.jl handle this.

DhairyaLGandhi commented 3 years ago

The loop is meant to make training easier and we have seen it being useful in many cases. I am with you in wanting to do away with it, but also want to make sure the rest of the tooling is available, and actually make it more powerful without sacrificing performance. I don't think there is an ideal loop, but one that can hold most cases, and have the tooling to extend and offer a guiding principle around best practices related to various bits of training is definitely worth it.

ToucheSir commented 3 years ago

I don't mind having something like train! around, but perhaps the question is how many people are using it because that's all they need for their model, and how many are using it because it looks like what you should be using. For example, Keras has fit, so train! must == fit and thus I should use only train! (with the implication that custom loops are scary and for advanced users only).

I don't think there is an ideal loop, but one that can hold most cases, and have the tooling to extend and offer a guiding principle around best practices related to various bits of training is definitely worth it.

I agree with this, but "hold most cases" is a slippery slope into something like https://github.com/lorenzoh/FluxTraining.jl/blob/master/src/train.jl. Again, there's the question of intent/scope. If Flux is fine with being an opinionated, multilevel framework, then something like train! makes perfect sense. If Flux is looking to be less opinionated on higher-level APIs or eschew them altogether, then train! sticks out like a sore thumb.

atiyo commented 3 years ago

While there might not be a single ideal loop, I think it should be possible to have customisation that fits in nicely with train! while maintaining easy-to-use defaults.

Frameworks like PyTorchLightning are quite high level, but allow for custom training loops, for example.

For something similar in Flux, we could introduce a new method for train! that accepts a function which describes a single training step.
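A rough sketch of what such a method could look like, in plain Julia with no Flux dependency. The name `train_steps` and the shape of the step function are hypothetical and not part of Flux; the toy model is a single scalar weight with a hand-written gradient.

```julia
# Hypothetical helper, NOT a Flux API: the loop is fixed, the step is user-supplied.
# `step(state, d)` takes the current state and one datapoint and returns the new state.
function train_steps(step, state, data)
    for d in data
        state = step(state, d)
    end
    return state
end

# Toy data for y = 2x.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]

# The do-block is passed as the `step` argument.
w = train_steps(0.0, data) do w, d
    x, y = d
    grad = 2 * (w * x - y) * x   # hand-written gradient of (w*x - y)^2
    w - 0.05 * grad              # one gradient-descent update
end
```

Everything specific to the training paradigm lives in the step function, so the loop itself needs no callbacks.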

I have no strong feelings about the above, but thought I would raise it since PyTorchLightning's abstractions for training loops seem to have attracted some fans. The fact that PyTorchLightning's abstractions are not native to PyTorch might indicate the value in having separate high- and low-level libraries.

ToucheSir commented 3 years ago

That's what something like FluxTraining is doing already. You may want to have a look at the extensive back-and-forths we've had about training loop design on Zulip.

darsnack commented 3 years ago

I agree with @CarloLucibello here. There are so many ways to design a training loop that it's going to be impossible to handle every case. train! really serves as a pedagogical example for how each piece of the training iteration comes together in Flux.

FluxTraining itself relies on other packages for pieces of the full workflow. Trying to put even a simple portion of that into Flux seems like asking for maintenance overheads that we can't service. It also doesn't add much value to the Flux ecosystem. Julia users have no qualms about installing Flux + FluxTraining (or any other high level package). Multiple packages with the correct abstractions is how our ecosystem works.

lorenzoh commented 3 years ago

I also think that this doesn't need to be something Flux.jl handles. I would rather have a full-fledged solution and I think that is out-of-scope for Flux.jl itself considering the complexity of FluxTraining.jl.

Multiple packages with the correct abstractions is how our ecosystem works.

+1 this, composable packages are the way to go where possible.

CarloLucibello commented 3 years ago

We could adopt a gentle deprecation path since FluxTraining.jl is not ready for debut (or is it?): remove train! from the docs and the model-zoo examples for Flux v0.12, deprecate it in v0.13, and remove it later.

DhairyaLGandhi commented 3 years ago

I definitely think that adding the correct abstractions is an important bit. FluxTraining.jl is a very opinionated package in terms of training routines, so it's harder to justify it as a catch-all. Its flexibility imo should come from making the callbacks consistent and available more easily to be used directly with the same kind of semantics as Flux. I feel there is benefit to having the train function in, because it describes the semantics we expect and is sufficient for most models, but we need to message it appropriately to suggest that it might be used in multiple ways, or that the for loop is a first class api that may be preferred for different packages and training routines, and add examples showing it.

DhairyaLGandhi commented 3 years ago

This might mean that we flesh out the docs or the function and point to more directly catered packages in the ecosystem. I don't see how that takes away from the composable nature of the Julia ecosystem; rather, it formalizes how we have built the abstractions so far.

lorenzoh commented 3 years ago

Regarding FluxTraining.jl: if you want to do standard supervised learning it already works great and has a similar feature set to fastai's training loop (barring mixed-precision training and advanced data augmentation schemes).

It is also possible to write custom training loops for things like GAN training, though not always elegantly due to how the state is set up. So there is some more design work to be done to make it possible to support other training paradigms cleanly. I haven't done that yet since I am doing pretty basic supervised training in my current work; maybe someone who works more actively with other training paradigms like GANs and self-supervised learning can weigh in on what is missing from FluxTraining.jl to support those use cases.

darsnack commented 3 years ago

the for loop is a first class api

This is something that I completely agree with.

This might mean that we flesh out the docs or the function and point to more directly catered packages in the ecosystem

More complete examples with for loops would be a good idea.

train! is not a for loop. It is a loop internally, but the interface exposed to users is a single function. This is why there is a need for callbacks at all. The user has no access to the for loop, so callbacks are necessary to allow the user an entry point.

making the callbacks consistent and available more easily to be used directly with the same kind of semantics as Flux

The implementation of a callback in simple for loop semantics is a function call. Doing anything more complicated would only make the for loop less simple.

More documentation for how users can write their own train! loops seems like the answer here instead of designing a callback system.
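As an illustration of that point, here is a hedged sketch of a hand-written loop where the "callback" is an ordinary function call. It assumes a recent Flux release with the explicit-gradient API (`Flux.setup`, `Flux.withgradient`, `Flux.update!`), which postdates this thread; the data, model size, and `log_loss` helper are made up.

```julia
using Flux  # assumption: a recent Flux release with the explicit-gradient API

# Toy regression problem: y = 2x, one feature, four samples.
xs = Float32[1 2 3 4]
ys = 2f0 .* xs
data = [(xs, ys)]

model = Dense(1 => 1)
opt_state = Flux.setup(Descent(0.01), model)

# The "callback": a plain function that records the loss.
losses = Float32[]
log_loss(l) = push!(losses, l)

for epoch in 1:50
    for (x, y) in data
        # Loss and gradient from a single forward/backward pass.
        l, grads = Flux.withgradient(m -> Flux.mse(m(x), y), model)
        log_loss(l)                                # call it wherever it is needed
        Flux.update!(opt_state, model, grads[1])   # in-place parameter update
    end
end
```

Because the user owns the loop, the loss, gradients, and optimiser state are all in scope, and no hook mechanism is needed to reach them.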

lorenzoh commented 3 years ago

More documentation for how users can write their own train! loops seems like the answer here instead of designing a callback system.

Agreed, train! is only a small improvement API-wise but a large restriction in extensibility compared to writing a simple for-loop

DhairyaLGandhi commented 3 years ago

train! is not a for loop.

That is exactly correct and thanks for bringing it up. This also harkens back to https://github.com/FluxML/Flux.jl/issues/1461#issue-784158849 where part of the goal would be to expose functionality within this loop, be that through pre/posthooks or scheduling points that give control to user-written code in other ways.

Doing anything more complicated would only make the for loop less simple

Yes, and this thread is to weigh those schemes. I don't think having stubs would necessitate complicating the function to any meaningful degree, as long as the objective is to make tinkering with the loop possible.

large restriction in extensibility compared to writing a simple for-loop

Since the question here is to see how to make extensible designs for the train function, I think this is subsumed?

DhairyaLGandhi commented 3 years ago

I was thinking of something like this to expose the loop to users. Users can add containers to hold some params, and run arbitrary code both before and after the optimisation step:

struct Callback{T}
  losses::AbstractVector{T}
end

# Arguments proposed for every hook:
# l: loss at the datapoint
# ps: params (could be skipped, but passing them avoids globals)
# gs: grads at the datapoint, available to inspect
# d: datapoint
# opt: optimiser, so a hook can modify it based on some condition

(cb::Callback)(l, ps, gs, d, opt) = append!(cb.losses, l)

prehook(l, ps, gs, d, opt) = throw(Flux.Optimise.SkipException())

c = Callback(Float32[])
Flux.train!(loss, ps, data, opt, cb = [() -> (), c])
Flux.train!(loss, ps, data, opt, prehooks = prehook)
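To make the proposal concrete, here is a self-contained sketch of a loop with a prehook and a posthook, written in plain Julia so it runs without Flux. All names (`SkipStep`, `LossRecorder`, `train_hooks!`) are hypothetical stand-ins; the `prehooks` keyword above is a proposal and this is only one way the semantics could work. The model is a single weight with a hand-written gradient.

```julia
# Hypothetical names throughout; none of this is Flux API.
struct SkipStep <: Exception end          # stand-in for a skip exception

struct LossRecorder{T}
    losses::Vector{T}
end
# Posthook: record the loss of every step that was actually applied.
(cb::LossRecorder)(l, w, g, d) = push!(cb.losses, l)

function train_hooks!(w, data; lr = 0.05,
                      prehook = (args...) -> nothing,
                      posthook = (args...) -> nothing)
    for d in data
        x, y = d
        l = (w[1] * x - y)^2              # loss at this datapoint
        g = 2 * (w[1] * x - y) * x        # hand-written gradient
        try
            prehook(l, w, g, d)           # runs before the optimisation step
        catch e
            e isa SkipStep || rethrow()
            continue                      # skip the update for this datapoint
        end
        w[1] -= lr * g                    # gradient-descent update
        posthook(l, w, g, d)              # runs after the optimisation step
    end
    return w
end

# Prehook: skip datapoints whose loss is not finite.
skip_nonfinite(l, w, g, d) = isfinite(l) || throw(SkipStep())

data = [(1.0, 2.0), (NaN, 4.0), (3.0, 6.0)]
rec = LossRecorder(Float64[])
w = train_hooks!([0.0], data; prehook = skip_nonfinite, posthook = rec)
```

The second datapoint produces a NaN loss, so the prehook skips it and only two losses are recorded.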