EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
443 stars · 62 forks · source link

Regression in Broadcasting support #1476

Status: Closed — avik-pal closed this issue 4 months ago.

avik-pal commented 4 months ago
using Enzyme, Statistics

"""
    _affine_normalize(identity, x, xmean, xvar, scale, bias, epsilon)

Normalize `x` with the given statistics and apply the affine parameters `scale`/`bias`.

The statistics are first folded into one broadcastable multiplier
`scale / sqrt(xvar + epsilon)` and one offset `bias - xmean * multiplier`. The result
`x * multiplier + offset` is then produced in a single fused broadcast over `x`.
Mathematically this equals `scale * (x - xmean) / sqrt(xvar + epsilon) + bias`.

All operations are element-wise broadcasts, so the arguments only need compatible
broadcast shapes. The result eltype follows normal promotion; for example, `Float32`
arrays with a `Float64` `epsilon` produce a `Float64` result.
"""
function _affine_normalize(::typeof(identity), x::AbstractArray, xmean, xvar,
    scale::AbstractArray, bias::AbstractArray, epsilon::Real)
    # Per-element multiplier: scale / √(var + ε)
    multiplier = scale ./ sqrt.(xvar .+ epsilon)
    # Per-element offset that absorbs the mean shift: bias - mean * multiplier
    offset = bias .- xmean .* multiplier
    # One fused pass over x
    return x .* multiplier .+ offset
end

"""
    loss_function(x, scale, bias)

Scalar loss for the first reproducer: the sum of squares of the normalized output.

`x` is reshaped to `6×6×3×2×2`, and `scale`/`bias` to `1×1×3×2×1`. The mean and
uncorrected variance are taken over dims `(1, 2, 5)` and passed to `_affine_normalize`
together with `epsilon = 1e-5`.

The sizes are hard-coded. `length(x)` must be 432, and `length(scale)` and
`length(bias)` must each be 6; otherwise `reshape` throws.
"""
function loss_function(x, scale, bias)
    stat_dims = (1, 2, 5)
    param_shape = (1, 1, 3, 2, 1)

    x5 = reshape(x, 6, 6, 3, 2, 2)
    scale5 = reshape(scale, param_shape)
    bias5 = reshape(bias, param_shape)

    batch_mean = mean(x5; dims=stat_dims)
    batch_var = var(x5; corrected=false, mean=batch_mean, dims=stat_dims)

    normalized = _affine_normalize(identity, x5, batch_mean, batch_var, scale5, bias5, 1e-5)
    return sum(abs2, normalized)
end

# Random inputs sized for the hard-coded reshapes in `loss_function`:
# 6*6*6*2 == 6*6*3*2*2 elements for `x`, and 6 == 3*2 entries for each parameter vector.
x = rand(Float32, 6, 6, 6, 2)
sc = rand(Float32, 6)
bi = rand(Float32, 6)

# Primal evaluation.
loss_function(x, sc, bi)

# Reverse-mode call with a zero-initialized shadow for every array argument.
# The issue reports this call raising the DimensionMismatch error shown next.
let dx = Enzyme.make_zero(x), dsc = Enzyme.make_zero(sc), dbi = Enzyme.make_zero(bi)
    Enzyme.autodiff(Reverse, loss_function, Active,
        Duplicated(x, dx), Duplicated(sc, dsc), Duplicated(bi, dbi))
end
ERROR: DimensionMismatch: destination axes (Base.OneTo(1), Base.OneTo(1), Base.OneTo(3), Base.OneTo(2), Base.OneTo(1)) are not compatible with source axes (Base.OneTo(1), Base.OneTo(1), Base.OneTo(3), Base.OneTo(2), Base.OneTo(1))
Stacktrace:
  [1] copyto!
    @ ./broadcast.jl:992 [inlined]
  [2] copyto!
    @ ./broadcast.jl:956 [inlined]
  [3] copy
    @ ./broadcast.jl:928 [inlined]
  [4] materialize
    @ ./broadcast.jl:903 [inlined]
  [5] _affine_normalize
    @ /mnt/research/lux/enzyme_bcast.jl:46
  [6] loss_function
    @ /mnt/research/lux/enzyme_bcast.jl:59 [inlined]
  [7] diffejulia_loss_function_15319wrap
    @ /mnt/research/lux/enzyme_bcast.jl:0
  [8] macro expansion
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:5855 [inlined]
  [9] enzyme_call
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:5521 [inlined]
 [10] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/UZsMX/src/compiler.jl:5400 [inlined]
 [11] autodiff
    @ ~/.julia/packages/Enzyme/UZsMX/src/Enzyme.jl:291 [inlined]
 [12] autodiff(::ReverseMode{…}, ::typeof(loss_function), ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/UZsMX/src/Enzyme.jl:303
 [13] top-level scope
    @ /mnt/research/lux/enzyme_bcast.jl:68
Some type information was truncated. Use `show(err)` to see complete types.

This worked in Enzyme.jl versions before v0.12.8.

avik-pal commented 4 months ago
using Enzyme

# Stand-in for broadcast-shape combination in the minimized reproducer.
# Returns `x` unchanged when `x == y`. Otherwise it calls `error(2)`, which throws an
# `ErrorException` with message "2". `@noinline` keeps it as a separate call.
@noinline function _bcs2(x, y)
    if x != y
        error(2)
    end
    return x
end

# Minimized reproducer. It hand-builds the lazy `Broadcasted` object for the unary
# broadcast `(+).(x)` and runs Base's `copyto!` on it into a fresh Float32 array.
# Presumably this isolates the `copyto!` path that appears in the first report's
# stack trace. It returns `x` unchanged; the broadcast result written to `dest`
# is discarded.
@noinline function _affine_normalize(x::AbstractArray)
    # Axes of the broadcast. `_bcs2` replaces the earlier variants of this line,
    # `broadcast_shape(axes(x), axes(x))` and `Broadcast.combine_axes(x, x)`.
    _axes = _bcs2(axes(x), axes(x))
    # Lazy broadcast of unary `+` over `x`.
    # NOTE(review): the style is hard-coded as `DefaultArrayStyle{2}`, so this assumes
    # `x` is a matrix (the driver below uses a 2×3 array). Confirm before reusing elsewhere.
    i = Broadcast.Broadcasted(Base.Broadcast.DefaultArrayStyle{2}(), +, (x,), _axes)

    # Uninitialized Float32 destination with the same axes.
    dest = similar(Array{Float32}, _axes)
    # Strip the style parameter. This presumably mirrors the conversion Base's generic
    # `copyto!(dest, bc::Broadcasted)` does before dispatching to the `Broadcasted{Nothing}`
    # method. TODO: confirm against the Julia version in use.
    bc = convert(Broadcast.Broadcasted{Nothing}, i)

    # Alternative tried during minimization: `mycopyto!(dest, bc)` (not defined in this snippet).
    copyto!(dest, bc)
    # `dest` is intentionally not returned. Presumably only running `copyto!` matters for the repro.
    return x
end

"""
    loss_function(x)

Scalar loss for the minimized reproducer: element `[1]` of `_affine_normalize(x)`.
"""
function loss_function(x)
    out = _affine_normalize(x)
    return out[1]
end

# A 2×3 Float32 input. It is two-dimensional, matching the `DefaultArrayStyle{2}`
# hard-coded inside `_affine_normalize(x)`.
x = rand(Float32, (2, 3))

# Primal evaluation, followed by the reverse-mode gradient of the same call.
loss_function(x)
Enzyme.gradient(Reverse, loss_function, x)
wsmoses commented 4 months ago

Fixed in main