Here's a simplification:
using DiffEqFlux, Flux, Zygote, NeuralPDE, ModelingToolkit, DomainSets
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
sym_prob = NeuralPDE.symbolic_discretize(pde_system,discretization1)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # Very very different???
## Fixed
initθ = DiffEqFlux.initial_params(fastchain)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
_, re = Flux.destructure(fluxchain) # needed for the phi override below
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ,
phi = (x,p)->re(p)(x))
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
sym_prob = NeuralPDE.symbolic_discretize(pde_system,discretization1)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # it's fine now!
Notice that it computes incorrect gradients with initθ = Float64.(DiffEqFlux.initial_params(fastchain)), as evidenced by the fact that one case is very different from the other three (FastChain Float32, FastChain Float64, and Chain Float32 all match; Chain Float64 is very different and is the only one that doesn't train correctly).
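For completeness, here is a sketch (mine, not verbatim from the issue) of how the four cases could be compared side by side; it reuses pde_system, grid_strategy, fastchain, fluxchain, and initθ from above:
for (chain, θ) in ((fastchain, Float32.(initθ)), (fastchain, Float64.(initθ)),
                   (fluxchain, Float32.(initθ)), (fluxchain, Float64.(initθ)))
    disc = NeuralPDE.PhysicsInformedNN(chain, grid_strategy; init_params = θ)
    prob = NeuralPDE.discretize(pde_system, disc)
    grad = Zygote.gradient(p -> prob.f(p, nothing), θ)[1]
    @show typeof(θ) grad[1:3]   # only the (Chain, Float64) combination stands out
end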
I thought it would be a simple destructure/restructure bug with floating point types, but isolating it failed:
using DiffEqFlux, Flux, Adapt, Zygote
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = DiffEqFlux.initial_params(fastchain)
p,re = Flux.destructure(fluxchain)
x = Float32[1.5,0.5]
dx1,dp1 = Zygote.gradient((x,p)->sum(fastchain(adapt(Array,x),p)),x,initθ)
dx2,dp2 = Zygote.gradient((x,p)->sum(re(p)(adapt(Array,x))),x,initθ)
dx1 ≈ dx2 # true
dp1 ≈ dp2 # true
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
x = Float64[1.5,0.5]
dx3,dp3 = Zygote.gradient((x,p)->sum(fastchain(x,p)),x,initθ)
dx4,dp4 = Zygote.gradient((x,p)->sum(re(p)(x)),x,initθ)
dx3 ≈ dx1 # true
dx4 ≈ dx1 # true
dp3 ≈ dp1 # true
dp4 ≈ dp1 # true
But it goes away if I do f64 on the fluxchain:
using DiffEqFlux, Flux, Zygote, NeuralPDE, ModelingToolkit, DomainSets
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1)) |> f64
initθ = Float64.(DiffEqFlux.initial_params(fastchain))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
Zygote.gradient((x)->prob2.f(x,nothing),initθ) # it's fine now!
We worked around it here by just changing the number type ourselves, so NeuralPDE is safe, but @CarloLucibello @mcabbott this is a pretty dangerous bug to have lurking around. Have you considered merging @DhairyaLGandhi's branch https://github.com/FluxML/Optimisers.jl/tree/dg/noproject and adding a test to catch this in the future?
Have not tried to reproduce this, but this change https://github.com/FluxML/Optimisers.jl/commit/9c61c8a32a30c6d6565c102e41f11d4c98f3f22d looks like it ought to allow you to make a MWE, or at least to figure out what types are actually involved here. It does not look safe to merge.
My attempts at an MWE failed, but maybe @DhairyaLGandhi found a nicer one. I think you need a map right after the restructure or something, but 🤷 my isolations all worked, so it's something rather specific.
Can you tell me what this prints, and also post the stacktrace somewhere?
julia> using Optimisers
julia> @eval Optimisers begin
function _getat(y::AbstractArray, o::Int, flat::AbstractVector)
res = ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))
if eltype(res) != eltype(y)
@info "found one" summary(y) summmary(flat) summary(res)
end
res
end
end
_getat (generic function with 2 methods)
julia> using DiffEqFlux, Flux, NeuralPDE, ModelingToolkit, DomainSets
So far my attempt to install everything is a bit stuck... failed on 1.7 and 1.9.
That "skipping precompilation" message is something that shows up with any package on v1.7 if you update and reuse the environment without restarting the REPL.
@DhairyaLGandhi identified the right spot, but his fix is incorrect. Here's a deterministic example:
using DiffEqFlux, Flux, Zygote, NeuralPDE, ModelingToolkit, DomainSets
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
fastchain = FastChain(FastDense(2,12,Flux.σ),FastDense(12,12,Flux.σ),FastDense(12,1))
fluxchain = Chain(Dense(2,12,Flux.σ),Dense(12,12,Flux.σ),Dense(12,1))
initθ = range(0,1,length=205)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
## Fixed
initθ = Float32.(range(0,1,length=205))
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# (Float32[0.34133404, 0.440552, 0.5395764, 0.63829243, 0.73679507, 0.83497345, 0.9328923, 1.0304716, 1.1278142, 1.2247037 … 98.49864, 98.546234, 98.59311, 98.63931, 98.6693, 98.71631, 98.76674, 98.80206, 98.82138, 99.81404],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# (Float32[0.3413074, 0.44055194, 0.5395802, 0.6383001, 0.7367189, 0.83488965, 0.9327855, 1.0305097, 1.127738, 1.2246275 … 98.49864, 98.546234, 98.59311, 98.63931, 98.6693, 98.71631, 98.76674, 98.80206, 98.82138, 99.81404],)
# Doesn't do anything:
@eval Optimisers begin
function _getat(y::AbstractArray, o::Int, flat::AbstractVector)
res = ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y)))
if eltype(res) != eltype(y)
@info "found one" summary(y) summmary(flat) summary(res)
end
res
end
end
initθ = range(0,1,length=205)
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization1 = NeuralPDE.PhysicsInformedNN(fastchain,
grid_strategy;
init_params = initθ)
discretization2 = NeuralPDE.PhysicsInformedNN(fluxchain,
grid_strategy;
init_params = initθ)
prob1 = NeuralPDE.discretize(pde_system,discretization1)
prob2 = NeuralPDE.discretize(pde_system,discretization2)
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
# Doesn't do anything
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
T.(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34197112172842026, -3.558639347553253, 4.5405376851558685, 0.6394826173782349, 0.7381760030984879, -3.1634015142917633, -7.0652690678834915, 17.032554239034653, -6.869950324296951, -14.772801548242569 … -413.50527000427246, 98.54232025146484, 610.5882167816162, 98.63247489929199, 98.6751537322998, 98.71630668640137, 610.7559909820557, 610.7942523956299, 98.83114624023438, 99.81404876708984],)
# Fixed!!! ?
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
@show Float64.(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
using Optimisers
@eval Optimisers begin
_getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T =
reshape(flat[o .+ (1:length(y))], axes(y))
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
using Optimisers
@eval Optimisers begin
function _getat(y::AbstractArray{T}, o::Int, flat::AbstractVector) where T
@show eltype(y), eltype(flat)
reshape(flat[o .+ (1:length(y))], axes(y))
end
end
Zygote.gradient((x)->prob1.f(x,nothing),initθ)
# ([0.34135572161301464, 0.4405596388580093, 0.5395470482221245, 0.6382976197739982, 0.7367915738790711, 0.8350087259178556, 0.9329294568383262, 1.030534260497729, 1.1278035963200221, 1.2247180249089877 … 98.49443032452626, 98.54203213497166, 98.58792677663844, 98.63218175029297, 98.6748692924491, 98.71603782132827, 98.75571771403575, 98.7939947935279, 98.8308816355171, 99.81405020016672],)
Zygote.gradient((x)->prob2.f(x,nothing),initθ)
# (eltype(y), eltype(flat)) = (Float32, Float64)
# ([0.34135578866824007, 0.44055960533039673, 0.539547040771544, 0.6382977315327074, 0.736791544076749, 0.8350087706213394, 0.932929464288907, 1.030534171090762, 1.127803417506088, 1.2247179504031824 … 98.49443127820058, 98.54203022762303, 98.58792582296412, 98.63217888927002, 98.67485975570594, 98.7160397286769, 98.75572534343029, 98.79398334943609, 98.83090070900344, 99.81405020016672],)
Instead of going down to Float32, it needs to widen to Float64 to be correct. What's the reason for that ProjectTo(y)? Maybe it needs to have a promote_type in there?
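As a standalone illustration of my own (not from the thread), a ProjectTo built from a primal converts a wider cotangent back to the primal's eltype:
using ChainRulesCore
y  = rand(Float32, 3)      # the primal stored in the model
dy = rand(Float64, 3)      # a wider cotangent, e.g. coming from the Float64 flat vector
ProjectTo(y)(dy)           # a Float32 array: projection narrows back to the primal's eltype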
# (eltype(y), eltype(flat)) = (Float32, Float64)
This is what's expected when the primal is Float32 for this variable y, but Float64 for others. The flat vector has to widen, but ProjectTo(y) ensures that the reconstructed gradient for y is correctly made Float32.
If something else wants a Float64 gradient for a Float32 variable, then maybe that's the problem.
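A small self-contained sketch (mine, assuming the current Optimisers.jl behaviour) of the widening and projection being described:
using Optimisers, Zygote
flat, re = Optimisers.destructure((a = Float32[1, 2], b = [3.0]))
eltype(flat)                                   # Float64: the flat vector widens to hold both leaves
eltype(re(flat).a)                             # Float32: the leaf keeps its original eltype via ProjectTo
g = Zygote.gradient(v -> sum(abs2, re(v).a), flat)[1]
eltype(g)                                      # Float64, matching flat, even though a itself stays Float32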
I understand why it exists now, but I don't understand why the type of y is considered the be-all and end-all here. The parameters inside the neural network are never used in this example; the NN is only used for its structure and restructured with new values. Yet the type of the values encoded in there will silently change the precision of the solution, even though the user tries to avoid those values ever existing. The only place those values creep in happens to be one part of the backwards pass, where they end up being used for a type conversion, so even though the values are never used you have to be careful about their type. I can't be the only one who sees that as weird action at a distance?
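A tiny standalone example of that action at a distance (mine, using a plain NamedTuple instead of the NN):
using Optimisers
template = (w = Float32[1, 2, 3],)    # the Float32 values here are never meant to be used
p, re = Optimisers.destructure(template)
p64 = Float64[10, 20, 30]             # the user supplies fresh Float64 parameters
re(p64).w                             # Float32[10.0, 20.0, 30.0]: values come from p64, eltype from the template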
Basically, why wouldn't re(p) take on the element type of p, rather than preserving the types it previously had? This also explains some of the issues with GPUs then, because re(p) where the NN is CPU-based and p is GPU-based doesn't build a GPU-based version of the NN. This also explains some of the issues with TrackedArrays, etc. It explains a lot of the bugs we see now, but I just don't understand why that behaviour needs to exist.
Basically, why wouldn't re(p) take on the element type of p, rather than preserving the types it previously had?
Because its one job is reconstruction? It's explicitly designed to allow for mixed precisions, and not to forget them. And it's not just precisions; the help example is:
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([3, 5-im, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
This also explains some of the issues with GPUs then, because re(p) where the NN is CPU-based and p is GPU-based
No, it does not. destructure on something containing a mix of GPU and CPU arrays is essentially undefined behaviour. (IIRC it depends on the order of arrays.) I'm not sure I follow what you think the behaviour should be, but it could be made to be something. Make an issue if this case is useful.
This also explains some of the issues with TrackedArrays. Etc.
Does it? What issues?
Because its one job is reconstruction?
Well, that's why I'm confused why it's doing more than just reconstruction. It's not just reconstructing the array p into the neural network architecture, it's also changing its values to not match p, so it's not a reconstruction of p into the form of the destructured thing. I expected:
julia> re([3, 5-im, 7+11im])
(x = [3+0*im, 5-im], y = (sin, [7 + 11im]))
i.e. it would be "the same form" but with the values of p. In fact, your example there shows something very scary because it changed 5-im to 5.0: those aren't just different types, but different values, because 5-im != 5.0! Also, it doesn't seem to be very consistent. With x it tries to convert everything back to Float64, even though the values are not representable as Float64s. But with y, it's perfectly happy with ComplexF64[7.0 + 11.0im] instead of changing it back to [7 + 11im] when the original values were in Complex{Int}? So there's no guarantee that the values are the same as p, there's no guarantee that the types match those of the destructured thing, and there's no guarantee that it matches the types of p! ComplexF64 shows up as the input to none of the functions, but still shows up in the result. What is the rule then? I honestly would prefer to just get an error in this case, as this is definitely not what I expected.
No, it does not. destructure on something containing a mix of GPU and CPU arrays is essentially undefined behaviour. (IIRC it depends on the order of arrays.) I'm not sure I follow what you think the behaviour should be, but it could be made to be something. Make an issue if this case is useful.
No, I'm saying I would've expected:
julia> re(cu([3, 5-im, 7+11im]))
(x = cu([3+0*im, 5-im]), y = (sin, cu([7 + 11im])))
"It's the same as the destructured thing, but with the values taken from p
, x
is just p[1:2]
and y is just (sin,p[3:3])
" is the rule. Simple, straightforward, and always matches the values of p
. This also would make it support TrackedArray
, since then
julia> re(Tracked([3, 5-im, 7+11im]))
(x = Tracked([3+0*im, 5-im]), y = (sin, Tracked([7 + 11im])))
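A minimal sketch of that rule (a hypothetical naive_re helper of my own, not the Optimisers.jl implementation):
# "always obey p": slice p into the original structure, never convert
naive_re(p) = (x = p[1:2], y = (sin, p[3:3]))
naive_re([3, 5-im, 7+11im])       # (x = Complex{Int}[3+0im, 5-1im], y = (sin, Complex{Int}[7+11im]))
naive_re(Float64[0.1, 0.2, 0.3])  # eltype (and storage) always follow p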
But anyways, now I'm worried about that complex case: that should definitely be counted as a bug IMO, or throw an error.
Right, p is the actual source of truth in the case of reconstruction, and my fix was only showing that we don't respect that right now. This is currently only doing a basic eltype conversion, which comes at the cost of extra copies. I didn't push that as a PR since we need to be certain that we want to let Julia figure out the types and reconstruct the model with p as the source of truth.
It is actually related to the same issue as why we expect some custom leaf types to return structures as adjoints in the backwards pass, and therefore need to reconstruct the type as opposed to operating on the fields directly. Complex is a case of that behaviour since its interaction with Numbers is specialised. In this case, I would much rather get a MethodError saying that there is no accum(::MyType, ::NamedTuple{fieldnames(MyType)}), or better yet be given a method we can use to reconstruct the type back from a primal and its gradient.
The fundamental issue here is that destructure has been pulled in two mutually incompatible directions by two disparate use cases:
1. reconstructing a model from new parameter values, where p is the source of truth;
2. a lossless round trip of model -> flat vector -> model, where the original model is the source of truth.
Chris has already described why no type conversion during reconstruction makes sense for 1). For 2), I recall we went down this path after many issues where users were expecting the following invariants to be upheld:
a. p, re = destructure(m); re(p) should be the identity operation.
b. gradient(m -> loss(m, ...), m) |> flatten == gradient(p -> loss(re(p), ...), p)
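As a quick sanity check of invariant (a), a sketch of my own (not from the thread) using a homogeneous Float32 layer:
using Flux, Optimisers
m = Dense(2, 3)                                          # all-Float32 layer
p, re = Optimisers.destructure(m)
x = rand(Float32, 2, 5)
re(p)(x) ≈ m(x)                                          # true: the round trip reproduces the layer
eltype(Optimisers.destructure(re(p))[1]) == eltype(p)    # true: eltypes survive the round trip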
The complex example is actually a great one because it shows how these can be broken without type conversion. If you pass in a Dense(weight=Float[], bias=Complex[]), you get out a Dense(weight=Complex[], bias=Complex[]) and break a). This means your gradient will be a NamedTuple(weight=Complex[], bias=Complex[]). If the imaginary component of weight's gradient is non-zero, then whoops, the training trajectories will diverge and b) will no longer hold.
I think the only way to resolve this tension is to bifurcate the destructure interface. Perhaps it would make sense to still keep the same name for both functions, but I don't see a way to support both use cases using the same code path.
For use as 1, is it really too much to ask that you make the "template" model with the desired number type? That seems like a simpler, easier-to-understand API than having some special additional mode to know about, document, and test.
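This is essentially the |> f64 workaround already shown above; a short recap of my own of building the template in the desired precision:
using Flux
fluxchain64 = Chain(Dense(2,12,Flux.σ), Dense(12,12,Flux.σ), Dense(12,1)) |> f64
p64, re = Flux.destructure(fluxchain64)
eltype(p64)    # Float64, so re(p64) and its gradients stay in Float64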
At present the re from a CPU model will work happily on a GPU v, or the reverse -- the location is not stored. (This isn't really by design; it just falls out of re-using ProjectTo for convenience, which never checks that the gradient has the same storage location.)
For use as 1, is it really too much to ask that you make the "template" model with the desired number type?
If that's the case it should probably throw an error instead of computing incorrect gradients. The complex number case would be a particularly nasty one to try and debug. Even finding this behavior took a long time.
But there are no incorrect gradients here. Like everything else using ChainRules, these are dy_final = ProjectTo(y)(dy_raw). Complex numbers were literally the original motivating case for introducing such projection operators. And allowing Float64 gradients for Float32 variables was, for a long time, the number-one way to accidentally get awful performance from Flux.
Your complaint is, if I understand right, entirely that _, re = destructure(m); re(new_p) reconstructs a model with the same element types as m, rather than always following new_p. I'm sorry if this was very surprising, but it's now clearly documented: https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.destructure
At present it re from a CPU model will work happily on a GPU v, or the reverse -- the location is not stored. (This isn't really by design, it just falls out of re-using ProjectTo for convenience, which never checks that gradient has the same storage location.)
And that makes it surprising. re(::CuArray) gives a GPU-based version, but re(::Array{Complex}) gives a Float32 version. It would at least be easier to understand if there were consistency here. If it either always obeyed p or always obeyed the functor, then either way it would at least be easy to guess what the restructured object would do. I thought it was "always obey p" because of how it acted with GPUs.
Chris has already described why no type conversion during reconstruction makes sense for 1). For 2), I recall we went down this path after many issues where users were expecting the following invariants to be upheld: a. p, re = destructure(m); re(p) should be the identity operation. b. gradient(m -> loss(m, ...), m) |> flatten == gradient(p -> loss(re(p), ...), p)
This whole discussion has been about 2. The issue is that type conversion presents itself as incorrect gradients in the case of (2). A calculation which says "I want to use Complex{Float64}" will silently use Float32, returning 0s for the complex parts and computing with incorrect precision on the real parts. The only way this is exposed to the user is by checking the gradient calculation (something that isn't a user-level property anyways), so in practice it's just hidden as "it didn't train".
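A hedged toy of my own (not the NeuralPDE code) showing the silent down-conversion being described:
using Optimisers
template = (w = Float32[1.0, 2.0],)          # real Float32 template
_, re = Optimisers.destructure(template)
p = ComplexF64[1.0 + 1.0im, 2.0 - 1.0im]     # "I want to use Complex{Float64}"
re(p).w                                      # Float32[1.0, 2.0]: the imaginary parts are dropped before the loss ever sees them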
One issue here is that it barely even shows itself in the forward pass unless you look at re(p)(u) in isolation. Here we do things like u0 .+ re(p)(u), and so if you look at any diagnostic function you see complex in -> complex out and things look "fine", but only because you didn't check that re(p)(u) actually down-converted p from complex to real before it was up-converted again in the next operation. The way it would expose itself is again only in the adjoints, because you'd see + 0im everywhere.
Look back at https://github.com/FluxML/Flux.jl/pull/1901#issuecomment-1065788096. Now that I've finally isolated this 5 months later, I realize that this behavior change is what caused the downstream tests to fail, sometimes, depending on which Optimisers version was resolved. The precision change caused a higher probability of test failure (since it still used random initializations), so the tests actually found it, but run them enough times with the last one green and it looked like a fluke. Almost imperceptibly the behavior just became "things are a little bit more janky these days, nobody really knows why", until I finally isolated it to the gradient precision being different from the precision that was specified. It might now be clearly documented, but this is very easy to hit accidentally and very hard to diagnose unless you already know it's possible. Multiple people took a look and no one realized that passing Float64s around isn't a good idea if you forget |> f64.
a. p, re = destructure(m); re(p) should be the identity operation.
I don't see why that should be the case. p is an array, so it will type-promote. If you then call re(p), you'll get the operations of whatever p is, which would be the promoted versions (or at least, that's how I thought it worked).
I thought it was "always obey p" because of how it acted with GPUs.
Great, well now that the documentation is clear, no need to guess.
type conversion presents itself as incorrect gradients
Again, no incorrect gradients have been exhibited here.
which says "I want to use Complex{Float64}"
The way you say this is by making the primal complex. Real numbers may not have complex gradients. The answer to "which way is uphill from the summit of this mountain?" cannot sensibly be a complex vector. Allowing that was a source of truly mysterious gradient bugs.
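Concretely, in ChainRulesCore terms (a one-liner of my own):
using ChainRulesCore
ProjectTo(1.0f0)(2.0 + 3.0im)    # 2.0f0: a real primal only ever receives the real part of a cotangent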
behavior change is what caused the downstream tests to fail
As you know, the old destructure was cobbled together in 5 minutes, had approximately zero tests for all of its years, and had many, many known bugs. It's inevitable that, sadly, some code relied on such bugs.
surprising. re(::CuArray) gives a GPU-based version
If you think this ought to be yet more strict, please make an issue.
Continuing this discussion upstream.
MWE: see the code above; the fluxchain case fails and the gradient is off.