denizyuret opened 3 years ago
This does not work:
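```julia
using AutoGrad  # provides the @primitive macro

# relu with optional clipping (max_value), leak (negative_slope) and threshold
function relu(x::T; max_value=Inf, negative_slope=0, threshold=0) where T
    (x >= max_value ? oftype(x, max_value) : x >= threshold ? x : negative_slope == 0 ? zero(T) : negative_slope * (x - oftype(x, threshold)))
end

# derivative of relu w.r.t. x, multiplied by the incoming gradient dy
function reluback(x::T, dy::T; max_value=Inf, negative_slope=0, threshold=0) where T
    dy * (x >= max_value ? zero(T) : x >= threshold ? one(T) : oftype(x, negative_slope))
end

# register gradients; the keyword options o... are passed through to reluback
@primitive relu(x; o...),dy,y reluback.(x,dy; o...)
@primitive reluback(x,dy; o...),ddx,dx zero(dy)
```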
Gives the error:

```
MethodError: no method matching broadcasted(::typeof(relu), ::Array{Float64, 4}; threshold=1)
```
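It turns out that a dot-call with keyword arguments is lowered differently from one without: instead of calling `Base.broadcasted` directly, it goes through `Base.broadcasted_kwsyntax`: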
```
julia> Meta.@lower relu.(a; threshold=1)
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope'
1 ─ %1 = Base.broadcasted_kwsyntax
│   %2 = Core.tuple(:threshold)
│   %3 = Core.apply_type(Core.NamedTuple, %2)
│   %4 = Core.tuple(1)
│   %5 = (%3)(%4)
│   %6 = Core.kwfunc(%1)
│   %7 = (%6)(%5, %1, relu, a)
│   %8 = Base.materialize(%7)
└──      return %8
))))
```
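To illustrate why a broadcast rule keyed on `typeof(relu)` may never be reached by `relu.(a; threshold=1)`, here is a minimal plain-Julia sketch (no AutoGrad involved). It assumes that `Base.broadcasted_kwsyntax` behaves as the lowered output above suggests, wrapping the call in an anonymous closure whenever keywords are present. `relu` here is a simplified stand-in for the function in this issue, and `Tagged` and `:custom_rule` are hypothetical names standing in for a tracked array type and its broadcast rule. This does not reproduce AutoGrad's actual internals.

```julia
# Simplified, hypothetical stand-in for the issue's relu (only `threshold` kept)
relu(x; threshold=0) = x >= threshold ? x : zero(x)

a = [-1.0, 0.5, 2.0]

# relu.(a) lowers to Base.broadcasted(relu, a): the lazy Broadcasted
# object still stores `relu` itself in its `f` field.
bc_plain = Base.broadcasted(relu, a)
println(bc_plain.f === relu)          # true

# relu.(a; threshold=1) lowers to Base.broadcasted_kwsyntax, which
# broadcasts an anonymous closure that captures the keywords instead.
bc_kw = Base.broadcasted_kwsyntax(relu, a; threshold=1)
println(bc_kw.f === relu)             # false: f is a closure, not relu
println(Base.materialize(bc_kw))      # [0.0, 0.0, 2.0]

# Hypothetical wrapper type standing in for a tracked (AD) value
struct Tagged
    x::Vector{Float64}
end
Base.broadcastable(t::Tagged) = t.x

# Hypothetical broadcast rule that dispatches on typeof(relu)
Base.Broadcast.broadcasted(::typeof(relu), t::Tagged) = :custom_rule

t = Tagged(a)
println(relu.(t))                     # custom_rule: the rule is hit
println(relu.(t; threshold=1))        # [0.0, 0.0, 2.0]: the rule is bypassed
```

Because the keyword form dispatches on the closure's type rather than on `typeof(relu)`, any method specialized on `::typeof(relu)` only ever sees the keyword-free dot-call.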