EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
449 stars 63 forks source link

Enzyme segfaults on the following MWE #2032

Open ArbitRandomUser opened 2 days ago

ArbitRandomUser commented 2 days ago

run mwe.jl in the uploaded folder.

enzymemwe.zip

or below

ArbitRandomUser commented 2 days ago

mwe.jl

using Enzyme
include("NN.jl")
using LinearAlgebra

const layers = [Dense(4,2*4*4+10,relu!),Dense(2*4*4+10,2*4*4,relu!)]
const layers2 = [Dense(4,2*4*4+10,relu!),Dense(2*4*4+10,2*4*4,softmax!)] #this nn fails to AD

const nn = makeNN(layers)
const nn2 = makeNN(layers2)

##cell
function make_dfunc(nn)
    outp = zeros(32)
    scratch = zeros(8)
    amps = zeros(4)
    function dfunc(du,u,nnparams,t)
        set_nnparams(nn,nnparams)
        applyNN!(nn,amps,outp)
        nothing
    end
    return dfunc,nn
end

u0 = rand(8)
der_u0 = zero(u0)

du0 = zeros(8)
der_du0 = rand(length(du0))

dfunc,_ = make_dfunc(nn)
dfunc2,_ = make_dfunc(nn2)

nnparams = get_nnparams(nn)
nnparams2 = get_nnparams(nn2)

dfunc(du0,u0,nnparams,0.1)

println("first nn")
res1 = autodiff(Reverse,Duplicated(dfunc,make_zero(dfunc)),Duplicated(du0,der_du0),Duplicated(u0,der_u0),Duplicated(nnparams,make_zero(nnparams)),Active(0.1))
println("second nn") #segfault
res2 = autodiff(Reverse,Duplicated(dfunc2,make_zero(dfunc2)),Duplicated(du0,der_du0),Duplicated(u0,der_u0),Duplicated(nnparams,make_zero(nnparams)),Active(0.1))
ArbitRandomUser commented 2 days ago

NN.jl

# fast simple barebones cpu only allocation free dense neural networks 
# use Enzyme for gradients 

import Base.zero

function tanh!(ret::Array,x::Array)
    ret .= tanh.(x)
    nothing
end

function relu!(ret,x)
    ret .= max.(0.0, x)
    nothing
end

function softmax!(ret,x)
    ret .= exp.(x)
    ret .= ret ./ sum(ret)
    nothing
end

struct Dense{T,F<:Function}
    n_inp::Int
    n_nodes::Int
    W::Matrix{T}
    b::Vector{T}
    activation::F
end

"""
    dense layer ,
    f ::activation (should take arguments (ret,inp) and store outputs on ret. check `relu` for more details 
    randfn :: random function called randfn(a,b) used to initialize the layers matrix 
"""
function Dense(n_inp, n_nodes, f::Function, randfn::Function = rand)
    Dense(n_inp, n_nodes, randfn(n_nodes, n_inp), randfn(n_nodes), f)
end

struct NN{T,L<:Tuple}
    n_inp::Int
    layers::L # Tuple of Dense
    intermediates::Vector{Vector{T}} # preallocated vectors for output of layers
end

"""
    make an NN , consequent layers must have matching inputs and number of nodes
    (i.e n_nodes of i'th layer == n_inp of i+1th layer) 
    #TODO automate this to be nicer. 
"""
function makeNN(n_inp, layers::Array, T::Type = Float64)
    @assert length(layers) >= 1
    @assert n_inp == layers[1].n_inp
    """ assert consecutive layers match in input and nodes"""
    for i in eachindex(layers)[1:end-1]
        @assert layers[i].n_nodes == layers[i+1].n_inp
    end
    NN(n_inp, Tuple(layers), Vector{T}[zeros(layer.n_nodes) for layer in layers])
end

function makeNN(layers::Array,T::Type=Float64)
    makeNN(layers[1].n_inp,layers,T)
end

"""
    get number of parameters in the nn
"""
function paramlength(nn::NN)
    r = 0
    for l in nn.layers
        r = r + length(l.W)
        r = r + length(l.b)
    end
    return r
end

"""
    get the parameters of the nn flattened in an array
"""
function get_nnparams(nn::NN)
    ret = Float64[]
    for l in nn.layers
        append!(ret, l.W)
        append!(ret, l.b)
    end
    return ret
end

function set_denseparams(d::Dense,arr)
    d.W .= reshape(view(arr,1:length(d.W)),size(d.W))
    d.b .= view(arr,length(d.W)+1:length(d.W)+1+length(d.b)-1)
end

"""
    set a flattened array of params to nn. (possibly not type stable if layers have different activations)
    Note, This does not error if params is larger than number of params of the nn.
    (we dont assert because this could be be part of a hotloop)
"""
function set_nnparams2(nn, nnparams)
    i = 1
    for j in 1:length(nn.layers)
        ll= nn.layers[j]
        set_denseparams(ll, view(nnparams,i:i+length(ll.W)+length(ll.b)-1) )
        i=i + length(ll.W) + length(ll.b)
    end
    nothing
end

"""
    set a flattened array of params to nn. (this is type stable)
    Note, This does not error if params is larger than number of params of the nn.
"""
@generated function set_nnparams(nn::NN{T, <:NTuple{N, Any}}, nnparams) where {T, N}
    quote    
    i = 1
    Base.Cartesian.@nexprs $N j -> begin
        #l = nn.layers[j]
        #l.W  .= reshape(view(nnparams,i:(i+length(l.W)-1)),size(l.W))
        #i = i + length(l.W)
        #l.b .= view(nnparams,i:(i+length(l.b)-1))
        #i = i + length(l.b)
        ll= nn.layers[j]
        set_denseparams(ll, view(nnparams,i:i+length(ll.W)+length(ll.b)-1) )
        i=i + length(ll.W) + length(ll.b)
    end
    nothing
    end
end

"""
    returns a similar nn with all 0 params and intermediates
    (use make_zero instead if making shadow for autodiff)
"""
function Base.zero(nn::NN)
    newnn = deepcopy(nn)
    for l in newnn.layers
        l.W .= 0.0
        l.b .= 0.0
    end
    for inter in newnn.intermediates
        inter .= 0.0
    end
    return newnn
end

"""
    apply dense layer on inp and store the result in out.
    inp : a vector of d.inp size.
    out : a vector of d.nodes size.
    note! uses mul!, `inp` and `out` should not be aliased.
"""
function applydense!(d::Dense, inp, out)
    mul!(out, d.W, inp, 1.0, 0.0)
    out .+= d.b
    d.activation(out,out)
    nothing
end

"""
    apply neural network nn on vector `inp` and store result in `out`
"""
function applyNN!(nn::NN, inp, out)
    applydense!(nn.layers[1], inp, nn.intermediates[1])
    for i in eachindex(nn.layers)[2:end]
        applydense!(nn.layers[i], nn.intermediates[i-1], nn.intermediates[i])
    end
    out .= nn.intermediates[end]
    nothing
end
ArbitRandomUser commented 2 days ago

Project.toml

[deps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
wsmoses commented 1 day ago

what julia version did you use?

ArbitRandomUser commented 1 day ago

1.10.5

julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Xeon(R) CPU E3-1270 v5 @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)