mcabbott / Tullio.jl

⅀
MIT License

Gradient missing for `begin ... end` #90

Open roflmaostc opened 3 years ago

roflmaostc commented 3 years ago
julia> f1(x) = sum(@tullio res[i, j] := begin
                       x[i+2, j] - 2 * x[i+1, j] + x[i, j]
                   end)
f1 (generic function with 1 method)

julia> gradient(f1, x)
ERROR: no gradient definition here!
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] (::Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒢𝓀ℯ#11"{var"#𝒜𝒸𝓉!#10"}, Nothing}, Tuple{Matrix{Float64}}, Matrix{Float64}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Tullio ~/.julia/packages/Tullio/bgqFi/src/grad/zygote.jl:7
 [3] (::Tullio.var"#85#back#217"{Tullio.var"#215#216"{Tullio.Eval{var"#ℳ𝒢𝓀ℯ#11"{var"#𝒜𝒸𝓉!#10"}, Nothing}, Tuple{Matrix{Float64}}, Matrix{Float64}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
   @ Tullio ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] Pullback
   @ ./REPL[39]:1 [inlined]
 [5] (::typeof(∂(f1)))(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [6] (::Zygote.var"#41#42"{typeof(∂(f1))})(Δ::Float64)
   @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:40
 [7] gradient(f::Function, args::Matrix{Float64})
   @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
 [8] top-level scope
   @ REPL[40]:1

julia> f1(x) = sum(@tullio res[i, j] := x[i+2, j] - 2 * x[i+1, j] + x[i, j])
f1 (generic function with 1 method)

julia> gradient(f1, x)
([1.0 1.0 … 1.0 1.0; -1.0 -1.0 … -1.0 -1.0; … ; -1.0 -1.0 … -1.0 -1.0; 1.0 1.0 … 1.0 1.0],)
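The rows of this gradient follow directly from the stencil. Assuming the index range `i ∈ 1:size(x, 1)-2` (the only range on which `x[i, j]`, `x[i+1, j]` and `x[i+2, j]` all stay in bounds), each output element sends sensitivities +1, -2 and +1 back to three inputs. Interior rows receive all three contributions, which cancel to zero, so only the first two rows (1, -1) and the last two rows (-1, 1) survive, matching the printout. A minimal hand-written pullback in plain Julia reproduces the pattern; the name `stencil_sum_grad` is hypothetical and this is not Tullio's generated code:

```julia
# Hand-written pullback of sum(res) for the stencil
#   res[i, j] = x[i+2, j] - 2 * x[i+1, j] + x[i, j],  i in 1:n-2,
# i.e. every res[i, j] receives a sensitivity of 1.
# Hypothetical helper for illustration; no packages needed.
function stencil_sum_grad(x::AbstractMatrix)
    n, m = size(x)
    dx = zeros(eltype(x), n, m)
    for j in 1:m, i in 1:n-2
        # scatter this output's unit sensitivity back to its three inputs
        dx[i+2, j] += 1
        dx[i+1, j] -= 2
        dx[i, j]   += 1
    end
    return dx
end

x = rand(6, 4)
g = stencil_sum_grad(x)   # rows: 1, -1, 0, 0, -1, 1
```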

I encountered this because my equation got quite lengthy.
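One way to keep a long expression readable without `begin ... end` may be to wrap the right-hand side in parentheses and break lines after a binary operator. A macro only receives the parsed `Expr`, and parentheses leave no node in it, so the multi-line form should reach `@tullio` exactly as the one-liner does, whereas `begin ... end` wraps the right-hand side in a `:block` node. The sketch below checks only the parsing side with `Meta.parse` (no Tullio involved); that `@tullio` then behaves identically for the parenthesized form is an assumption not verified here:

```julia
# Compare how Julia parses three spellings of the same stencil.
# Only Base is used; @tullio itself is not involved.
one_line = Meta.parse("res[i, j] := x[i+2, j] - 2 * x[i+1, j] + x[i, j]")

# Parentheses allow line breaks; ending a line with `+` keeps the expression going.
multi_line = Meta.parse("""
    res[i, j] := (x[i+2, j] - 2 * x[i+1, j] +
                  x[i, j])
    """)

with_block = Meta.parse("""
    res[i, j] := begin
        x[i+2, j] - 2 * x[i+1, j] + x[i, j]
    end
    """)

same_expr = one_line == multi_line   # true: parentheses leave no trace
rhs_head = with_block.args[2].head   # :block, introduced by begin ... end
```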

mcabbott commented 3 years ago

Here it could probably sort this out, but in general, when it sees `begin ... end`, Tullio concludes that you want an expression too complicated for its (very simple) symbolic differentiation to handle. It can still differentiate such expressions using dual numbers, via the `grad=Dual` option (and this is often no less efficient):

julia> using Tullio, Zygote, ForwardDiff

julia> f1(x) = sum(@tullio res[i, j] := begin
                       a = x[i+2, j]
                       b = -2 * x[i+1, j]
                       c = x[i, j]
                       a + b + c
                   end grad=Dual)
f1 (generic function with 1 method)

julia> gradient(f1, x)
([1.0 1.0 … 1.0 1.0; -1.0 -1.0 … -1.0 -1.0; … ; -1.0 -1.0 … -1.0 -1.0; 1.0 1.0 … 1.0 1.0],)
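For intuition about why the dual-number route copes with an arbitrary `begin ... end` body, here is a minimal forward-mode sketch in plain Julia. The `MyDual` type and the `stencil_body` and `partial` helpers are hypothetical illustrations, far simpler than ForwardDiff's `Dual` and not how Tullio is implemented internally. Evaluating the multi-statement body on dual numbers yields its partial derivatives without any symbolic rules:

```julia
# Toy forward-mode dual number: a value plus one derivative component.
# Hypothetical illustration only; ForwardDiff.Dual is far more general.
struct MyDual
    val::Float64
    der::Float64
end

# Only the operations the stencil body needs: addition and scaling by a constant.
Base.:+(a::MyDual, b::MyDual) = MyDual(a.val + b.val, a.der + b.der)
Base.:*(k::Real, b::MyDual) = MyDual(k * b.val, k * b.der)

# Same statements as the begin ... end block above, with
# x0 = x[i, j], x1 = x[i+1, j], x2 = x[i+2, j].
function stencil_body(x0, x1, x2)
    a = x2
    b = -2 * x1
    c = x0
    a + b + c
end

# Seed input k with derivative 1 (all others 0) to get d(body)/d(input k).
function partial(xs::NTuple{3,Float64}, k::Int)
    args = ntuple(m -> MyDual(xs[m], m == k ? 1.0 : 0.0), 3)
    return stencil_body(args...)
end

xs = (3.0, 5.0, 4.0)                      # arbitrary sample values
results = [partial(xs, k) for k in 1:3]
vals = [r.val for r in results]           # the body's value, the same for every seed
ders = [r.der for r in results]           # partials w.r.t. x0, x1, x2
```

The partials 1, -2 and 1, scattered back over the shifted indices `i`, `i+1` and `i+2`, are what produce the gradient printed above. With `grad=Dual`, Tullio relies on ForwardDiff's dual numbers (hence `using ForwardDiff`) rather than a toy type like this.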