FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Size mismatch between implicit parameters and their gradients #1015

Open JordiBolibar opened 3 years ago

JordiBolibar commented 3 years ago

I'm encountering a dimension mismatch between the implicit parameters of a Flux Chain and the gradients obtained from a pullback. This code is part of a Universal Differential Equation (UDE) combining a 2D PDE with a neural network (UA). The mismatch occurs in the first layer:

# Leaky ReLu as activation function
leakyrelu(x, a=0.01) = max(a*x, x)

# Constrains A within physically plausible values
relu_A(x) = min(max(1.58e-17, x), 1.58e-16)

# Define the networks 1->10->5->1
UA = Chain(
    Dense(1,10,initb = Flux.zeros), 
    BatchNorm(10, leakyrelu),
    Dense(10,5,initb = Flux.zeros), 
    BatchNorm(5, leakyrelu),
    Dense(5,1, relu_A, initb = Flux.zeros) 
)

ps_UA = Flux.params(UA)
loss_UA, back_UA = Zygote.pullback(() -> loss(H, UA, p, t, t₁), ps_UA)
∇_UA = back_UA(one(loss_UA))

for ps in ps_UA
    @show ps, ∇_UA[ps]
    println("size ps: ", size(ps))
    println("size ∇_UA[p]: ", size(∇_UA[ps]))
end
(ps, ∇_UA[ps]) = (Float32[0.30798444; 0.52923024; -0.38974404; -0.44488695; -0.007705779; -0.46781763; 0.30756167; -0.73545146; -0.60995233; -0.12095741], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10, 1)
size ∇_UA[p]: (10,)
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
(ps, ∇_UA[ps]) = (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
(ps, ∇_UA[ps]) = (Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
size ps: (10,)
size ∇_UA[p]: (10,)
(ps, ∇_UA[ps]) = (Float32[-0.49900624 0.1739867 -0.50695115 -0.012244531 0.47494888 0.3235831 0.021464685 -0.48237133 -0.586465 -0.38168398; -0.40725645 -0.26300266 -0.14521688 0.020233944 -0.07136398 -0.56981426 -0.05533645 0.16115816 -0.4485389 -0.56794554; -0.37900218 -0.08815088 0.10154217 0.558363 -0.22744176 0.12258495 0.18857977 -0.16126387 -0.45260283 -0.54091734; -0.47956002 -0.27310026 -0.43743765 0.032916818 0.095131814 -0.6059501 -0.40490097 0.43668085 -0.31058735 -0.21437271; -0.031416014 0.21674222 0.485597 -0.3657828 -0.24838457 0.52909964 0.44705272 0.16652822 0.5047817 -0.5061942], [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0])
type ps: Matrix{Float32}

This produces an error when trying to update the NN's parameters:

Flux.update!(opt, ps_UA, ∇_UA)

ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
  [1] check_broadcast_shape(#unused#::Tuple{}, Ashp::Tuple{Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:518
  [2] check_broadcast_shape(shp::Tuple{Base.OneTo{Int64}}, Ashp::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base.Broadcast ./broadcast.jl:521
  [3] check_broadcast_axes
    @ ./broadcast.jl:523 [inlined]
  [4] check_broadcast_axes
    @ ./broadcast.jl:526 [inlined]
  [5] instantiate
    @ ./broadcast.jl:269 [inlined]
  [6] materialize!
    @ ./broadcast.jl:894 [inlined]
  [7] materialize!
    @ ./broadcast.jl:891 [inlined]
  [8] apply!(o::ADAM, x::Matrix{Float32}, Δ::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/optimisers.jl:181
  [9] update!(opt::ADAM, x::Matrix{Float32}, x̄::Vector{Float64})
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:23
 [10] update!(opt::ADAM, xs::Params, gs::Zygote.Grads)
    @ Flux.Optimise ~/.julia/packages/Flux/qp1gc/src/optimise/train.jl:29

Here's some more context from a Discourse thread: https://discourse.julialang.org/t/unrecognized-gradient-using-zygote-for-ad-with-universal-differential-equations/63791/4

For now I cannot provide an MWE, as this is part of a big model that is quite tricky to simplify.

Is there any temporary workaround to this issue? Thanks!

mcabbott commented 3 years ago

It is unfortunately easy for Zygote to turn vectors into 1-column matrices, often via things like ([1,2,3]' .^ 2)' isa AbstractMatrix. They're all bugs but it's hard to catch every one.
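
A Base-only sketch of the pattern described above (no Zygote involved; variable names are illustrative): transposing, broadcasting, and transposing back keeps the values but returns a 3×1 matrix instead of a length-3 vector, and vec restores the vector shape.

```julia
using LinearAlgebra   # ' (adjoint) for arrays is defined in this stdlib

v = [1, 2, 3]         # a plain Vector, size (3,)

# v' is a 1×3 row. Dot-broadcasting over it allocates a new 1×3 Matrix,
# and taking the adjoint of that Matrix gives a 3×1 AbstractMatrix,
# not a Vector, even though the values are the squares of v.
w = (v' .^ 2)'
println(typeof(w), " with size ", size(w))

# Same values, vector shape again.
u = vec(w)
println(u, " with size ", size(u))
```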

I also think Flux.apply! should simply accept this. Trivial dimensions shouldn't matter to it. That would probably be quite simple to fix.

In Julia 1.7, ones(3) .= rand(3,1) is no longer an error, so upgrading may solve the problem, too.
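
A small Base-only sketch of that behaviour, with illustrative names: it checks whether the running Julia version accepts writing a 3×1 source into a length-3 destination with .=, then shows a form that works on any version by reshaping the source first. This is the same direction as in the stack trace (the destination has fewer dimensions than the right-hand side).

```julia
x = ones(3)                          # destination, size (3,)
Δ = reshape([1.0, 2.0, 3.0], 3, 1)   # source, size (3, 1)

# Does this Julia version let .= drop a trailing singleton dimension?
# According to the thread, this throws a DimensionMismatch before Julia 1.7.
accepts_trailing_singleton = try
    tmp = copy(x)
    tmp .= Δ
    true
catch err
    err isa DimensionMismatch || rethrow()
    false
end
println("Julia ", VERSION, ": (3,) .= (3,1) allowed? ", accepts_trailing_singleton)

# Version-independent form: give the source the destination's shape first.
x .-= 0.1 .* reshape(Δ, size(x))
println(x)
```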

JordiBolibar commented 3 years ago

A quick patch suggested by @mcabbott provided a workaround for this. But this should probably be handled internally by Zygote.

This solves it for now:

Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))
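
The patch relies on multiple dispatch: it adds a more specific update! method for an AbstractMatrix parameter paired with an AbstractVector gradient, reshapes the gradient, and forwards to the generic method. The sketch below reproduces that pattern on a hypothetical toy optimiser (toy_apply! and toy_update! are made up for illustration and are not Flux's implementation). Like the ADAM apply! in the stack trace, it writes its result back into the gradient in place, which is the step that fails without the reshape on Julia versions before 1.7.

```julia
# Hypothetical toy optimiser: state is shaped like the parameter x,
# and the step is written back into the gradient Δ in place.
function toy_apply!(x::AbstractArray, Δ::AbstractArray; η = 0.1)
    state = zero(x)          # optimiser state has the shape of x
    state .= state .+ Δ      # a Vector Δ would broadcast up to a 1-column state
    Δ .= η .* state          # writing a 1-column result into a Vector Δ is what errors pre-1.7
    return Δ
end

# Generic update: subtract the step computed by toy_apply!.
function toy_update!(x::AbstractArray, Δ::AbstractArray)
    x .-= toy_apply!(x, Δ)
    return x
end

# Shim mirroring the patch: a Vector gradient for a Matrix parameter is
# reshaped first, so the generic method sees matching sizes.
toy_update!(x::AbstractMatrix, Δ::AbstractVector) = toy_update!(x, reshape(Δ, size(x)))

x = ones(Float32, 3, 1)      # parameter stored as a 3×1 Matrix{Float32}
Δ = [1.0, 2.0, 3.0]          # gradient returned as a Vector{Float64}
toy_update!(x, Δ)
println(size(x), " ", vec(x))
```

Because reshape on an Array shares memory, the caller's gradient vector is overwritten too, as the ADAM apply! in the stack trace does with the gradient it receives.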

DhairyaLGandhi commented 3 years ago

Well, the language semantics should be respected in such cases. Seen that way, these "bugs" are effectively choices, and changing them may have subtle repercussions elsewhere. In this case, for instance, there may be different code paths for AbstractMatrix and AbstractVector.
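
To make the point about separate code paths concrete, a small Base Julia sketch (names are illustrative): the same three numbers stored as a Vector and as a 3×1 Matrix give different result types under the same expression.

```julia
using LinearAlgebra           # ' (adjoint) for arrays is defined in this stdlib

v = [1, 2, 3]                 # Vector, size (3,)
M = reshape(v, 3, 1)          # same data as a 3×1 Matrix

a = v' * v                    # row vector times vector: a scalar
b = M' * M                    # (1×3) * (3×1) matrix product: a 1×1 Matrix

println(a, " :: ", typeof(a))
println(b, " :: ", typeof(b))
```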

darsnack commented 3 years ago

Looking into this more, I was not able to reproduce the bug.

julia> # Leaky ReLu as activation function
       leakyrelu(x, a=1f-2) = max(a*x, x)
leakyrelu (generic function with 2 methods)

julia> # Constrains A within physically plausible values
       relu_A(x) = min(max(1.58f-17, x), 1.58f-16)
relu_A (generic function with 1 method)

julia> m = Chain(
           Dense(1,10,initb = Flux.zeros),
           BatchNorm(10, leakyrelu),
           Dense(10,5,initb = Flux.zeros),
           BatchNorm(5, leakyrelu),
           Dense(5,1, relu_A, initb = Flux.zeros)
       )
Chain(Dense(1, 10), BatchNorm(10, leakyrelu), Dense(10, 5), BatchNorm(5, leakyrelu), Dense(5, 1, relu_A))

julia> gs = gradient(() -> sum(m(Flux.ones(1, 5))), params(m))
Grads(...)

julia> for p in params(m)
       @show size(p), size(gs[p])
       end
(size(p), size(gs[p])) = ((10, 1), (10, 1))
(size(p), size(gs[p])) = ((10,), (10,))
(size(p), size(gs[p])) = ((10,), (10,))
(size(p), size(gs[p])) = ((10,), (10,))
(size(p), size(gs[p])) = ((5, 10), (5, 10))
(size(p), size(gs[p])) = ((5,), (5,))
(size(p), size(gs[p])) = ((5,), (5,))
(size(p), size(gs[p])) = ((5,), (5,))
(size(p), size(gs[p])) = ((1, 5), (1, 5))
(size(p), size(gs[p])) = ((1,), (1,))
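
The size check in the loop above can be packaged as a small helper that reports only the offending parameters. This is a hypothetical sketch: shape_mismatches is not part of Flux or Zygote, and an IdDict keyed by the parameter arrays stands in for the Zygote.Grads object from the thread, so using it with a real Grads object may need a different lookup.

```julia
# Hypothetical helper: return (index, param size, gradient size) for every
# parameter whose gradient has a different size. `grads` is looked up by the
# parameter array itself; missing or `nothing` gradients are skipped.
function shape_mismatches(params, grads)
    bad = Tuple{Int,Tuple,Tuple}[]
    for (i, p) in enumerate(params)
        g = get(grads, p, nothing)
        g === nothing && continue
        size(g) == size(p) || push!(bad, (i, size(p), size(g)))
    end
    return bad
end

# Toy data mirroring the first rows of the report: a 10×1 weight whose
# gradient came back as a length-10 Vector, and a bias that is fine.
W = rand(Float32, 10, 1)
b = zeros(Float32, 10)
grads = IdDict{Any,Any}(W => zeros(10), b => zeros(10))

for (i, sp, sg) in shape_mismatches([W, b], grads)
    println("param #", i, ": size ", sp, " but gradient size ", sg)
end
```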

I can only assume that something inside your loss function is causing the error then. Can you share that?
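
One difference between this reproduction and the original snippet is the literal types: 1f-2 and 1.58f-17 are Float32, whereas 0.01 and 1.58e-17 are Float64. A minimal sketch (leaky64 and leaky32 are illustrative names) showing that a Float64 constant promotes a Float32 activation to Float64. The thread does not establish whether this plays any role in the mismatch; it is simply one way Float64 values, like the Vector{Float64} gradient in the stack trace, can enter a Float32 model.

```julia
leaky64(x, a=0.01) = max(a*x, x)   # slope as a Float64 literal (original snippet)
leaky32(x, a=1f-2) = max(a*x, x)   # slope as a Float32 literal (reproduction)

x = 0.5f0                          # a Float32 activation
println(typeof(leaky64(x)))        # 0.01 * x is Float64, so max promotes to Float64
println(typeof(leaky32(x)))        # stays Float32
```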

JordiBolibar commented 3 years ago

As discussed here, we now have an MWE, based on a 2D heat equation UDE, that reproduces this issue.

Here's the MWE:

using LinearAlgebra
using Statistics
using Zygote
using PaddedViews
using Flux
using Flux: @epochs
using Tullio

#### Parameters
nx, ny = 100, 100 # Size of the grid
Δx, Δy = 1, 1
Δt = 0.01
t₁ = 1

D₀ = 1
tolnl = 1e-4
itMax = 100
damp = 0.85
dτsc   = 1.0/3.0
ϵ     = 1e-4            # small number
cfl  = max(Δx^2,Δy^2)/4.1

A₀ = 1
ρ = 9
g = 9.81
n = 3
p = (Δx, Δy, Δt, t₁, ρ, g, n)  # we add extra parameters for the nonlinear diffusivity

### Reference dataset for the heat Equations
T₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / 300 ) for i in 1:nx, j in 1:ny ];
T₁ = copy(T₀);

#######   FUNCTIONS   ############

# Utility functions
@views avg(A) = 0.25 * ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )

@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )

@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )

### Functions to generate reference dataset to train UDE

function Heat_nonlinear(T, A, p)

    Δx, Δy, Δt, t₁, ρ, g, n = p

    #### NEW CODE TO BREAK
    dTdx = diff(T, dims=1) / Δx
    dTdy = diff(T, dims=2) / Δy
    ∇T = sqrt.(avg_y(dTdx).^2 .+ avg_x(dTdy).^2)

    D = A .* avg(T) .* ∇T

    dTdx_edges = diff(T[:,2:end - 1], dims=1) / Δx
    dTdy_edges = diff(T[2:end - 1,:], dims=2) / Δy

    Fx = -avg_y(D) .* dTdx_edges
    Fy = -avg_x(D) .* dTdy_edges   

    F = .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) 

    dτ = dτsc * min.( 10.0 , 1.0./(1.0/Δt .+ 1.0./(cfl./(ϵ .+ avg(D)))))

    return F, dτ

end

# Fake law to create reference dataset and to be learnt by the NN
fakeA(t) = A₀ * exp(2t)

### Heat equation based on a fake A parameter function to compute the diffusivity
function heatflow_nonlinear(T, fA, p, fake, tol=Inf)

    Δx, Δy, Δt, t₁, ρ, g, n = p

    total_iter = 0
    t = 0

    while t < t₁

        iter = 1
        err = 2 * tolnl
        Hold = copy(T)
        dTdt = zeros(nx, ny)
        err = Inf 

        if fake
            A = fA(t)  # compute the fake A value involved in the nonlinear diffusivity
        else
            # Compute A with the NN once per time step
            A = fA([t]')[1]  # compute A parameter involved in the diffusivity
        end

        while iter < itMax+1 && tol <= err

            Err = copy(T)

            F, dτ = Heat_nonlinear(T, A, p)

            @tullio ResT[i,j] := -(T[i,j] - Hold[i,j])/Δt + F[pad(i-1,1,1),pad(j-1,1,1)] 

            dTdt_ = copy(dTdt)
            @tullio dTdt[i,j] := dTdt_[i,j]*damp + ResT[i,j]

            T_ = copy(T)
            #@tullio T[i,j] := max(0.0, T_[i,j] + dTdt[i,j]*dτ[pad(i-1,1,1),pad(j-1,1,1)]) 
            @tullio T[i,j] := max(0.0, T_[i,j] + dTdt[i,j]*dτ[pad(i-1,1,1),pad(j-1,1,1)])

            Zygote.ignore() do
                Err .= Err .- T
                err = maximum(Err)
            end 

            iter += 1
            total_iter += 1

        end

        t += Δt

    end

    if(!fake)
        println("Values of UA in heatflow_nonlinear: ", fA([0., .5, 1.]')) # Simulations here are correct
    end

    return T

end

# Patch suggested by Michael Abbott needed in order to correctly retrieve gradients
Flux.Optimise.update!(opt, x::AbstractMatrix, Δ::AbstractVector) = Flux.Optimise.update!(opt, x, reshape(Δ, size(x)))

function train(loss, p)

    leakyrelu(x, a=0.01) = max(a*x, x)
    relu(x) = max(0, x)

    UA = Chain(
        Dense(1,10,initb = Flux.glorot_normal), 
        BatchNorm(10, leakyrelu),
        Dense(10,5,initb = Flux.glorot_normal), 
        BatchNorm(5, leakyrelu),
        Dense(5,1, relu, initb = Flux.glorot_normal) 
    )

    opt = RMSProp()
    losses = []
    @epochs 10 hybrid_train_NN!(loss, UA, p, opt, losses)

    println("Values of UA in train(): ", UA([0., .5, 1.]'))

    return UA, losses

end

function hybrid_train_NN!(loss, UA, p, opt, losses)

    T = T₀
    θ = Flux.params(UA)
    println("Values of UA in hybrid_train BEFORE: ", UA([0., .5, 1.]'))
    loss_UA, back_UA = Zygote.pullback(() -> loss(T, UA, p), θ)
    push!(losses, loss_UA)

    ∇_UA = back_UA(one(loss_UA))

    for ps in θ
       println("Gradients ∇_UA[ps]: ", ∇_UA[ps])
    end

    println("θ: ", θ) # parameters are NOT NaNs
    println("Values of UA in hybrid_train AFTER: ", UA([0., .5, 1.]')) # Simulations here are all NaNs

    Flux.Optimise.update!(opt, θ, ∇_UA)

end

function loss_NN(T, UA, p, λ=1)

    T = heatflow_nonlinear(T, UA, p, false)
    l_cost = sqrt(Flux.Losses.mse(T, T_ref; agg=mean))

    return l_cost 
end

#######################

########################################
#####  TRAIN 2D HEAT EQUATION PDE  #####
########################################

T₂ = copy(T₀)
# Reference temperature dataset
T_ref = heatflow_nonlinear(T₂, fakeA, p, true, 1e-1)

# Train heat equation UDE
UA_trained, losses = train(loss_NN, p)
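
As a sanity check on the array shapes in Heat_nonlinear, here is a Base-only sketch that reuses the MWE's averaging helpers on a small grid (the grid size and random data are illustrative, and Δx = Δy = 1 is assumed as in the MWE, so the divisions are dropped). It shows that F ends up with size (nx-2, ny-2), which the MWE pads back to (nx, ny) with pad(i-1,1,1) in the @tullio lines; the manual zero-padding below assumes pad fills out-of-range entries with zero.

```julia
# Same staggered-grid helpers as in the MWE.
@views avg(A)   = 0.25 .* (A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end])
@views avg_x(A) = 0.5 .* (A[1:end-1,:] .+ A[2:end,:])
@views avg_y(A) = 0.5 .* (A[:,1:end-1] .+ A[:,2:end])

nx, ny = 6, 5                      # small illustrative grid
T = rand(nx, ny)
A = 1.0                            # scalar A, as produced by fA(t) in the MWE

dTdx = diff(T, dims=1)                             # (nx-1, ny)
dTdy = diff(T, dims=2)                             # (nx,   ny-1)
∇T   = sqrt.(avg_y(dTdx).^2 .+ avg_x(dTdy).^2)     # (nx-1, ny-1)
D    = A .* avg(T) .* ∇T                           # (nx-1, ny-1)

Fx = -avg_y(D) .* diff(T[:, 2:end-1], dims=1)      # (nx-1, ny-2)
Fy = -avg_x(D) .* diff(T[2:end-1, :], dims=2)      # (nx-2, ny-1)
F  = .-(diff(Fx, dims=1) .+ diff(Fy, dims=2))      # (nx-2, ny-2)

# Zero-pad F back to the full grid (assumed equivalent of pad(i-1,1,1)).
Fpad = zeros(nx, ny)
Fpad[2:end-1, 2:end-1] .= F
println(size(F), " -> ", size(Fpad))
```

The same reasoning gives dτ the size (nx-2, ny-2), since it is built from avg(D).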

There are currently two issues:

  1. When the line containing @mcabbott's patch is commented out, the size-mismatch error for the model parameters is reproduced.
  2. With the patch applied, the neural network (UA) only seems to work inside Zygote.pullback(), as the logs I added to the MWE show. The parameters outside the pullback look fine and are correctly updated, but UA itself is somehow broken and no longer seems to be linked to the implicit parameters.

Any ideas on what might be the culprit? Is there any workaround for this?

Thanks again!

DhairyaLGandhi commented 3 years ago

It would be awesome to reduce this example further, but IIUC the dimension mismatch should be solved in Julia 1.7.

JordiBolibar commented 3 years ago

Yes, I was aware of the potential fix in Julia 1.7, but the patch covers it for now.

There is a remaining issue, though: as I explained above, the Flux model doesn't seem to work outside the pullback. Is this a bug, or am I doing something wrong?

I have also trimmed the code a bit more. It is still pretty long, but it should reproduce the issue with a copy/paste :)