avik-pal closed this issue 2 months ago.
Technically this is an assertion error rather than a segfault, but in any case, here is a reduced reproducer:
```julia
using Enzyme, Strided, FastBroadcast

const gelu_2λ = √(8 / π)

# Reduced body: `y` is undefined here, but `gelu` is never actually called
# for the input below (see the length check in `__fast_broadcast!`).
function gelu(x)
    oftype(float(x), y)
end

@inline function __fast_broadcast!(x)
    # if ArrayInterface.fast_scalar_indexing(x)
    # This condition is false for the 125-element input below, so the branch never runs.
    if length(x) > 200_000
        p6 = Strided.maybestrided(x)
        p10 = Strided.maybestrided(x)
        p11 = Base.broadcasted(gelu, p10)
        bc = Base.Broadcast.combine_styles(p6, p11)
        # inst = Base.Broadcast.instantiate(Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes(p6)))
        Base.materialize!(Base.Broadcast.combine_styles(p6, p11), p6, p11)
    end
    return x
end

x = randn(Float32, 125)
# Enzyme.gradient(Enzyme.Reverse, x -> sum(__fast_broadcast!(gelu, x)), x)
f(x) = sum(__fast_broadcast!(x))
dx = Enzyme.make_zero(x)
# Reverse-mode AD of f; the gradient with respect to x is accumulated into dx.
Enzyme.autodiff(Enzyme.Reverse, f, Active, Duplicated(x, dx))
```
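Because `x` has only 125 elements and the broadcast is guarded by `length(x) > 200_000`, the branch never executes, so `f(x)` is just `sum(x)`. If `autodiff` succeeded, `dx` would be expected to hold all ones, which suggests the assertion is hit while Enzyme differentiates the untaken branch rather than while running it. The sketch below is a hypothetical, package-free reference check of that expectation using central finite differences; `reference_fast_broadcast!`, `reference_f`, `fd_gradient`, and the `tanh` placeholder work are illustrative assumptions, not part of the original reproducer.

```julia
# Hypothetical reference check (no Enzyme, Strided, or FastBroadcast needed).
# Mirrors the control flow of the reproducer: the broadcast branch only runs
# for inputs longer than 200_000 elements.
function reference_fast_broadcast!(x)
    if length(x) > 200_000
        x .= tanh.(x)  # placeholder work for the untaken branch (assumption)
    end
    return x
end

# Copy so the check never mutates the caller's array.
reference_f(x) = sum(reference_fast_broadcast!(copy(x)))

# Central finite differences of g with respect to each element of x.
function fd_gradient(g, x; h = 1.0e-3)
    grad = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        grad[i] = (g(xp) - g(xm)) / (2h)
    end
    return grad
end

# Float64 instead of the reproducer's Float32 to keep finite-difference error small.
x = randn(125)
grad = fd_gradient(reference_f, x)
# For this input size, every entry of grad should be (approximately) 1.
println(maximum(abs.(grad .- 1)))
```

The same comparison could be run against `dx` once the fixed Enzyme build is available.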
This will be fixed by https://github.com/EnzymeAD/Enzyme/pull/1862, but the fix will also require an Enzyme_jll bump.
I tried minimizing this further (see the commented-out lines), but no smaller variant seems to trigger the failure.
The Enzyme version is 0.12.