FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Flux docs missing withgradient() call for multi-objective loss functions #2325

Closed: jacob-m-wilson-42 closed this issue 8 months ago.

jacob-m-wilson-42 commented 1 year ago

After talking with @mcabbott on Slack, we determined that the Flux docs are missing documentation for using withgradient() with multi-objective loss functions. @mcabbott provided the code below:

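```julia
using Flux: withgradient
# model, x, loss_a and loss_b are assumed to be defined already
trio, grads = withgradient(model) do m
  a = loss_a(m, x)
  b = loss_b(m, x)
  (; c=a+b, a, b)  # only c, the first field, is differentiated
end
```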

Documentation for this currently exists in Zygote.jl but not in the Flux docs. The existing Zygote docs can be found here. As a suggestion, it might be appropriate to add this to the Loss Functions section, perhaps under a new header showing how to take gradients of multi-objective loss functions.
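A minimal, self-contained version of @mcabbott's snippet might look like the sketch below. It assumes a recent Flux release (for `Dense(in => out)` and `Flux.withgradient`) and a Zygote version whose `withgradient` accepts a NamedTuple return value; the model, data and the two loss terms are hypothetical stand-ins. The result is a NamedTuple `(val, grad)`: `val` is everything the do-block returned, while `grad` is the gradient of the first field, `c`, only.

```julia
using Flux                      # Dense, Flux.withgradient, Flux.gradient
using Statistics: mean

# Hypothetical toy setup: a tiny model, a batch of inputs and two loss terms.
model = Dense(2 => 1)
x = rand(Float32, 2, 8)
loss_a(m, x) = mean(abs2, m(x))           # hypothetical first loss term
loss_b(m, x) = mean(abs2, m(x) .- 1f0)    # hypothetical second loss term

out = Flux.withgradient(model) do m
    a = loss_a(m, x)
    b = loss_b(m, x)
    (; c = a + b, a, b)     # only the first field, c, is differentiated
end

trio, grads = out           # destructures the (val, grad) NamedTuple
@show trio.c trio.a trio.b  # the total and both individual terms
# grads[1] is the gradient of trio.c with respect to model
```

The individual terms are read from `trio.a` and `trio.b`, and `grads[1]` can be passed to the optimiser exactly as if only `c` had been returned.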

CarloLucibello commented 1 year ago

I'm not sure what we should add to the documentation. I'm not an expert on this, but in multi-objective optimization you in principle want to obtain the whole Pareto frontier. I wouldn't mention the term "multi-objective" anywhere in the documentation. We could stress in more places that withgradient can return additional outputs, if we think it deserves more prominence.

jacob-m-wilson-42 commented 1 year ago

The current Flux docs don't mention how to get multiple outputs from withgradient at all. We don't have to stress multi-objective optimization, but I think the docs definitely need to show how to track multiple losses using withgradient. Currently you have to look at the Zygote docs, which I couldn't find until @mcabbott pointed them out to me.

ToucheSir commented 1 year ago

I guess what confuses me is why this feature request is about multi-objective losses, mostly because this:

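```julia
using Flux: withgradient
# model, x, loss_a and loss_b are assumed to be defined already
loss, grads = withgradient(model) do m
  a = loss_a(m, x)
  b = loss_b(m, x)
  c = a + b
  return c
end
```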

is still training with a "multi-objective loss" from my perspective. What the recent withgradient change unlocked is the ability to return auxiliary state, which does not have gradients backpropagated through it. The example in the issue happens to use this mechanism to sneak out the individual losses for later use, but it would be equally valid to use it for, e.g., embeddings or non-differentiable metrics. So is this request about documenting how to use auxiliary state with withgradient and Flux models, with examples, or about showing how to optimize a model with multiple joint losses (which would be much the same as in other libraries: just return loss1 + loss2 + ...)?
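The point that the mechanism is not specific to losses can be illustrated with a sketch that returns hidden activations (standing in for an embedding) and a logging metric alongside the loss. It assumes a recent Flux providing `Flux.withgradient` and `Flux.Losses.mse`; the model, data and metric are hypothetical. A plain Tuple works as well as a NamedTuple here: only its first element is differentiated.

```julia
using Flux
using Statistics: mean

# Hypothetical regression model and data, for illustration only.
model = Chain(Dense(3 => 4, relu), Dense(4 => 1))
x = rand(Float32, 3, 16)
y = rand(Float32, 1, 16)

(loss, aux), grads = Flux.withgradient(model) do m
    h = m[1](x)                    # hidden activations, e.g. an embedding
    ŷ = m[2](h)
    l = Flux.Losses.mse(ŷ, y)      # the scalar that is differentiated
    mae = mean(abs.(ŷ .- y))       # a metric kept for logging only
    (l, (; h, mae))                # only the first element is differentiated
end

println("loss = $loss, mae = $(aux.mae), embedding size = $(size(aux.h))")
```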

mcabbott commented 1 year ago

https://github.com/FluxML/Flux.jl/pull/2331 suggests adding this bullet: https://github.com/FluxML/Flux.jl/pull/2331/files#diff-791e8b024a9ce7e7f89b45b7582d628d3d8d55f0bb5e17c39f8a50bd6aa21aeaR228-R230

jacob-m-wilson-42 commented 1 year ago

@ToucheSir, the functionality in your previous post is well documented: that code calculates the gradients and the total loss from the combined individual loss terms. I am interested in using withgradient to also return the value of each individual loss term (in your code, I would like to track the individual losses a and b every epoch). This doesn't affect how the network trains, but it gives some insight into how the network is performing on each individual loss term. I'm studying Physics-Informed Neural Networks (PINNs), where it is important to see how the network handles the initial, boundary and physics loss terms individually. Without the functionality below

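```julia
using Flux: withgradient
# model, x, loss_a and loss_b are assumed to be defined already
trio, grads = withgradient(model) do m
  a = loss_a(m, x)
  b = loss_b(m, x)
  (; c=a+b, a, b)  # only c, the first field, is differentiated
end
```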

you would have to call withgradient on loss components a and b separately and manually add the gradients together before passing them to the optimizer. I would like to see this syntax added to the Flux documentation, as it is currently only documented in the Zygote package.
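The two approaches can be compared in a small sketch. To keep the gradients easy to add by hand, it differentiates with respect to a plain vector `w` rather than a Flux model, and `loss_a`/`loss_b` are hypothetical stand-ins. Both give the same gradient, because the gradient of a + b is the sum of the individual gradients, but the single call needs one forward and backward pass instead of two.

```julia
using Flux

# Hypothetical loss terms on a plain parameter vector (stand-ins for loss_a, loss_b).
loss_a(w) = sum(abs2, w)
loss_b(w) = sum(w .^ 3)

w = [1.0, 2.0, -0.5]

# Single call: the individual terms come back as auxiliary outputs.
trio, (g,) = Flux.withgradient(w) do v
    a = loss_a(v)
    b = loss_b(v)
    (; c = a + b, a, b)
end

# Separate calls: one gradient per term, then summed by hand.
ra = Flux.withgradient(loss_a, w)
rb = Flux.withgradient(loss_b, w)
g_sum = ra.grad[1] .+ rb.grad[1]

println((trio, g, g_sum))
```

For a real Flux model the gradients are nested NamedTuples mirroring the model, so summing them by hand means walking that structure, which the single call avoids.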

@mcabbott, yes, even something small like that would really help, just to make it more explicit that this functionality exists.

ToucheSir commented 1 year ago

OK, in that case I agree with Carlo that the documentation should not mention multi-objective losses specifically, but should instead focus on getting auxiliary information out, perhaps with individual loss terms as an example.
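Using individual loss terms as the example, a training loop that logs each term every epoch in the PINN setting described above could be sketched as follows. It assumes Flux's explicit-style training API (`Flux.setup` and `Flux.update!`); the network, the collocation points and the three loss functions are hypothetical placeholders, and the "physics" term is not a real PDE residual.

```julia
using Flux
using Statistics: mean

# Hypothetical stand-ins for a PINN's three loss terms. A real physics loss
# would evaluate a PDE residual; here every term is a simple placeholder.
model = Chain(Dense(1 => 16, tanh), Dense(16 => 1))
x_ic = zeros(Float32, 1, 1)    # hypothetical initial-condition point
x_bc = ones(Float32, 1, 1)     # hypothetical boundary point
x_in = rand(Float32, 1, 32)    # hypothetical interior collocation points

loss_ic(m)  = mean(abs2, m(x_ic) .- 1f0)
loss_bc(m)  = mean(abs2, m(x_bc))
loss_pde(m) = mean(abs2, m(x_in))   # placeholder, not a real PDE residual

opt_state = Flux.setup(Adam(1f-2), model)
history = NamedTuple[]   # one entry of loss terms per epoch

for epoch in 1:100
    terms, grads = Flux.withgradient(model) do m
        ic, bc, pde = loss_ic(m), loss_bc(m), loss_pde(m)
        (; total = ic + bc + pde, ic, bc, pde)   # only `total` is differentiated
    end
    Flux.update!(opt_state, model, grads[1])
    push!(history, terms)   # log every individual term for this epoch
end

println(history[end])
```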