FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Implicit parameters for forwarddiff #424

Open theogf opened 4 years ago

theogf commented 4 years ago

Right now Zygote has forwarddiff (via ForwardDiff.jl, if I am not mistaken), but it cannot support implicit parameters: Zygote.forwarddiff(()->foo(k),params(k)) will not work. I suppose this is due to the ForwardDiff backend, but is there a fundamental barrier to a forward differentiation implementation that supports implicit parameters?

mcabbott commented 4 years ago

I'm not certain this is what you are asking, but what you write looks a bit like you're using it as an alternative to gradient, while I think it's meant to be used as an alternative to function application, i.e. instead of sum(x) in this:

julia> using Flux, Zygote

julia> x = [1,2,3];

julia> g = gradient(() -> 7 + Zygote.forwarddiff(sum, x)^2, params(x))
Grads(...)

julia> g[x]
3-element Array{Int64,1}:
 12
 12
 12

The function z -> 7 + z^2 is being handled backwards, and the function sum forwards.
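The split described above can be checked against a purely reverse-mode gradient. This is a minimal sketch that passes the array explicitly rather than through params; the Float64 inputs and the names g_mixed and g_reverse are chosen here for illustration and do not come from the thread.

```julia
using Zygote

# Float inputs chosen for illustration; the example above uses integers.
x = [1.0, 2.0, 3.0]

# sum is differentiated in forward mode, the outer z -> 7 + z^2 in reverse mode.
g_mixed = gradient(v -> 7 + Zygote.forwarddiff(sum, v)^2, x)[1]

# The same function differentiated entirely in reverse mode.
g_reverse = gradient(v -> 7 + sum(v)^2, x)[1]

# Both should equal 2 * sum(x) in every entry.
@show g_mixed ≈ g_reverse
```

Wrapping sum in forwarddiff changes only how that sub-computation is differentiated, not the value or the resulting gradient.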

theogf commented 4 years ago

Sorry, I think my example was a bit misleading. I am not aiming for mixed forward/reverse differentiation. I meant that forwarddiff cannot be used like gradient with params. Here is a better example, I hope:

foo() = sum(x)
x = [1,2,3]
Zygote.forwarddiff(()->foo(),params(x)) # This will not work

mcabbott commented 4 years ago

I'm still not sure I understand, but is this what you hope would return Grads(...)?

gradient(() -> Zygote.forwarddiff(_ -> foo(), params(x)), params(x))

Otherwise, of course, Zygote.forwarddiff(_ -> foo(), [0]) evaluates your function, but there is no AD involved. And taking the gradient of this is an error, because when it feeds in 0 + dual it does not get a dual-number result.

You can do this when defining your function, since at that point it has access to the input:

bar() = Zygote.forwarddiff(sum, x)
gradient(() -> bar(), params(x))

theogf commented 4 years ago

I'm still not sure I understand, but is this what you hope would return Grads(...)?

gradient(() -> Zygote.forwarddiff(_ -> foo(), params(x)), params(x))

Yes, exactly! Sorry, I now realize I misunderstood what forwarddiff was doing. Thanks a lot for the clarification!

Then my question is: is it conceivable to have a forwarddiff function that takes no arguments, or maybe a Zygote-like ForwardDiff.jl?

By the way my question is related to https://discourse.julialang.org/t/zygote-adjoint-with-matrices/32188/4

MikeInnes commented 4 years ago

In an ideal world, yes, a forwarddiff with no arguments (just using Zygote's parameter list) is exactly what we'd do here. Unfortunately, this depends on us having an SCT-style (source code transformation) forward-mode implementation, which doesn't seem like it's going to happen any time soon.
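
In the absence of such an implementation, one possible workaround is to pass the parameters explicitly and call ForwardDiff.jl directly. The sketch below assumes ForwardDiff.jl is installed; the one-argument foo(v) is a hypothetical rewrite of the thread's zero-argument foo(), used only for illustration and not proposed by anyone in the thread.

```julia
using ForwardDiff

x = [1.0, 2.0, 3.0]

# Hypothetical variant of foo that receives its parameters as an argument
# instead of reading the global x.
foo(v) = sum(v)

# Forward-mode gradient with respect to the explicit argument.
g = ForwardDiff.gradient(foo, x)
@show g
```

This gives up the implicit-parameter convenience of params, which is exactly the feature the issue asks for, so it is a stopgap rather than an answer to the request.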