Open theogf opened 4 years ago
I'm not certain this is what you are asking, but what you write looks a bit like you're using it as an alternative to gradient, while I think it's meant to be used as an alternative to function application, i.e. instead of sum(x) in this:
julia> using Flux, Zygote
julia> x = [1,2,3];
julia> g = gradient(() -> 7 + Zygote.forwarddiff(sum, x)^2, params(x))
Grads(...)
julia> g[x]
3-element Array{Int64,1}:
12
12
12
The function z -> 7 + z^2 is being differentiated in reverse mode, and the function sum in forward mode.
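As a quick sanity check on the 12s in that output, here is the chain rule written out by hand. The names s and f are just labels for this sketch and do not appear in the code:

```latex
s = \sum_{i=1}^{3} x_i = 1 + 2 + 3 = 6, \qquad f = 7 + s^2
\\[4pt]
\frac{\partial f}{\partial x_i}
  = \frac{\partial f}{\partial s}\,\frac{\partial s}{\partial x_i}
  = 2s \cdot 1
  = 12 \quad \text{for each } i
```

The factor 2s comes from the outer function z -> 7 + z^2, and the factor 1 comes from sum, which is the part that goes through forwarddiff.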
Sorry, I think my example was a bit misleading. I'm not aiming for mixed forward/reverse differentiation. I meant that forwarddiff cannot be used like gradient with params. Here is a better example, I hope:
foo() = sum(x)
x = [1,2,3]
Zygote.forwarddiff(()->foo(),params(x)) # This will not work
I'm still not sure I understand, but is this what you hope would return Grads(...)?
gradient(() -> Zygote.forwarddiff(_ -> foo(), params(x)), params(x))
Otherwise, of course, Zygote.forwarddiff(_ -> foo(), [0]) evaluates your function, but there is no AD involved. Taking the gradient of this is an error, because when it feeds in 0 + dual it does not get a dual-number result back.
You can do this at the stage of defining your function, since there it has access to the input:
bar() = Zygote.forwarddiff(sum, x)
gradient(() -> bar(), params(x))
I'm still not sure I understand, but is this what you hope would return Grads(...)?
gradient(() -> Zygote.forwarddiff(_ -> foo(), params(x)), params(x))
Yes, exactly! Sorry, I now realize I misunderstood what forwarddiff was doing. Thanks a lot for the clarifications!
Then my question is: is it conceivable to have a forwarddiff function that takes no arguments, or maybe a Zygote-like ForwardDiff.jl?
By the way, my question is related to https://discourse.julialang.org/t/zygote-adjoint-with-matrices/32188/4
In an ideal world, yes, a forwarddiff with no arguments (just using Zygote's parameter list) is exactly what we'd do here. Unfortunately, this depends on us having an SCT-style forward mode implementation, which doesn't seem like it's going to happen any time soon.
Right now Zygote has forwarddiff (via ForwardDiff.jl, if I am not mistaken), but it cannot support implicit parameters: Zygote.forwarddiff(()->foo(k),params(k)) will not work. I suppose this is due to the ForwardDiff backend, but would there be a fundamental barrier to a forward differentiation implementation that supports implicit parameters?