FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Issues with a model containing a custom gradient (w.r.t. input variable) layer #1760

Closed: zsteve closed this issue 3 years ago

zsteve commented 3 years ago

Hi,

I'm trying to set up a DiffEqFlux model in which a custom layer computes the gradient (w.r.t. the input variable) of a scalar-valued neural network, but I get an error when I try to train the model (presumably because gradients w.r.t. the parameters are then needed).

Update: I just saw #1518, and it seems that Zygote has issues with nested AD. However, I don't think I can simply take derivatives with respect to both the inputs and the parameters at the same time, because the gradient w.r.t. the input is fed into an SDE solver. Any help or pointers would be appreciated!

The code I've got is below. I'm using Flux.gradient to produce the custom gradient layer (I've also tried ForwardDiff, which throws a different error suggesting that ForwardDiff can't be used within Zygote). I'm using FastChain, as suggested here, to allow the parameters of the potential to pass through to the drift.

using Flux
using DifferentialEquations
using DiffEqFlux
using ReverseDiff
using ForwardDiff
using LinearAlgebra  # provides `norm`, used in the potential's regularization term

# set up 1d potential
potential_dudt = FastChain(FastDense(1, 50, tanh), FastDense(50, 5, tanh), FastDense(5, 5), (x, p) -> sum(x))
λ = 0.1
# I tried also using ForwardDiff, but it looks like ForwardDiff can't be used within Zygote
# grad_potential_dudt = (x, p) -> -ForwardDiff.gradient(z -> potential_dudt(z, p) + λ*norm(z)^2, x)[1]
grad_potential_dudt = (x, p) -> -Flux.gradient(z -> potential_dudt(z, p) + λ*norm(z)^2, x)[1]
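# grad_potential_dudt carries the same parameter vector as potential_dudt
DiffEqFlux.paramlength(::typeof(grad_potential_dudt)) = DiffEqFlux.paramlength(potential_dudt)
DiffEqFlux.initial_params(::typeof(grad_potential_dudt)) = randn(DiffEqFlux.paramlength(potential_dudt))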
drift_dudt = FastChain(grad_potential_dudt)
diffusion_dudt = FastChain((x, p) -> 0*x .+ 1.)

tspan = (0, 1.)
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, EM(),
                       dt = 0.01, reltol = 1e-1, abstol = 1e-1)
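For reference, here is the drift that grad_potential_dudt is meant to compute, written out explicitly. V_p(x) denotes the scalar output of potential_dudt for parameters p, and λ = 0.1 as in the code above. The regularization term has a closed-form gradient, so only ∇_x V_p needs AD. Training, however, also needs the derivative of the drift w.r.t. p, which is a mixed second derivative of V_p. That is exactly the reverse-over-reverse nesting that #1518 describes:

$$
f_p(x) = -\nabla_x\left[V_p(x) + \lambda \lVert x \rVert^2\right] = -\nabla_x V_p(x) - 2\lambda x,
\qquad
\frac{\partial f_p(x)}{\partial p} = -\frac{\partial}{\partial p}\,\nabla_x V_p(x).
$$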

The parameters seem to be set up correctly for the neuralsde object:

Flux.params(neuralsde)

> Params([[0.11302589034050317, -1.1522473763685657, 0.7929682576183352, 0.3861498137472427, 0.07313114051822761, 1.2341699100342243, 2.7601624470776143, -1.7536057888821417, 1.6026677222841252, -1.4909777278668033  …  0.4153961766998856, 0.5483026680321065, -1.569027760921727, -0.6441666255598087, 0.08935386494230572, 0.2280837310000524, 1.7409343223363576, 0.34413175160205506, -0.11620190817823804, -0.20239467916062465]])

and sol = neuralsde(randn(1, 100)) also works, so sampling from the SDE is not the problem; only training fails. I suspect this has to do with Zygote trying to differentiate (w.r.t. the network parameters) through the output of Flux.gradient (w.r.t. the input variable).

The minimal training code I am using is:

using OptimalTransport
using Distances

function loss_neuralsde(p; n = 100)
    ρ0 = Array(neuralsde(randn(1, n), p))[1, :, end]
    ρ1 = u[:, end]  # `u` (the target samples) is not defined in this snippet
    C01 = pairwise(SqEuclidean(), ρ0, ρ1)
    C11 = pairwise(SqEuclidean(), ρ1, ρ1)
    C00 = pairwise(SqEuclidean(), ρ0, ρ0)
    sinkhorn_divergence(fill(1/length(ρ0), length(ρ0)), fill(1/length(ρ1), length(ρ1)), C01, C00, C11, 1.0)
end

callback = function (p, loss)
    @info "callback"
    false
end

opt = ADAM()
DiffEqFlux.sciml_train((p) -> loss_neuralsde(p, n = 10),
                       neuralsde.p, opt,
                       cb = callback, maxiters = 100)
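The three cost matrices in loss_neuralsde correspond to the three entropic OT terms of the standard (debiased) Sinkhorn divergence. The formula below assumes that OptimalTransport.jl's sinkhorn_divergence follows this standard definition. Here OT_ε(ρ0, ρ1) is computed with C01, OT_ε(ρ0, ρ0) with C00, OT_ε(ρ1, ρ1) with C11, both marginals carry uniform weights, and ε = 1.0:

$$
S_\varepsilon(\rho_0, \rho_1) = \mathrm{OT}_\varepsilon(\rho_0, \rho_1) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\rho_0, \rho_0) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\rho_1, \rho_1)
$$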

The error is as below.

MethodError: *(::TrackedArray{…,Matrix{Tracker.TrackedReal{Float64}}}, ::TrackedArray{…,Adjoint{Float64, Matrix{Float64}}}) is ambiguous. Candidates:
  *(a::AbstractMatrix{var"#s550"} where var"#s550"<:Tracker.TrackedReal, b::Tracker.TrackedMatrix{T, A} where A) where T<:Real in DistributionsAD at /home/syz/.julia/packages/DistributionsAD/b93cZ/src/tracker.jl:199
  *(x::Tracker.TrackedMatrix{T, A} where {T, A}, y::Tracker.TrackedMatrix{T, A} where {T, A}) in Tracker at /home/syz/.julia/packages/Tracker/YNNTM/src/lib/array.jl:423
  *(x::AbstractMatrix{T} where T, y::Tracker.TrackedMatrix{T, A} where {T, A}) in Tracker at /home/syz/.julia/packages/Tracker/YNNTM/src/lib/array.jl:422
  *(a::AbstractArray{var"#s550", N} where {var"#s550"<:Tracker.TrackedReal, N}, b::Tracker.TrackedArray{T, N, A} where {N, A<:AbstractArray{T, N}}) where T<:Real in DistributionsAD at /home/syz/.julia/packages/DistributionsAD/b93cZ/src/tracker.jl:199
  *(x::Tracker.TrackedMatrix{T, A} where {T, A}, y::AbstractMatrix{T} where T) in Tracker at /home/syz/.julia/packages/Tracker/YNNTM/src/lib/array.jl:421
Possible fix, define
  *(::Tracker.TrackedMatrix{T, A} where {T<:Tracker.TrackedReal, A}, ::Tracker.TrackedMatrix{T, A} where A<:AbstractMatrix{T}) where T<:Real

Stacktrace:
[...]
ToucheSir commented 3 years ago

This is way too much code for an MWE, and the stacktrace is reminiscent of some recent issues where Zygote wasn't being used at all (note the presence of Tracker, the old Flux AD). If you're not able to create a DiffEq-free MWE, I'd recommend asking in their channels whether anyone has seen this before (I believe they have) and how best to address it.

zsteve commented 3 years ago

You are right, I should have tried to narrow the scope to exclude DiffEq, sorry! I was under the impression I was making some trivial mistake since I'm new to Flux. Will close and re-post a cleaned-up issue once ready.