denizyuret / AutoGrad.jl

Julia port of the Python autograd package.
Other
169 stars 26 forks source link

keyword args in broadcasted funcs not supported #124

Open denizyuret opened 3 years ago

denizyuret commented 3 years ago

This does not work:

"""
    relu(x; max_value=Inf, negative_slope=0, threshold=0)

Scalar rectified linear unit with optional saturation, leak and threshold.

The branches are tried in order:
- `x >= max_value`: return `max_value` converted to the type of `x`;
- `x >= threshold`: return `x` unchanged;
- `negative_slope == 0`: return `zero(T)`;
- otherwise: return `negative_slope * (x - threshold)`, with `threshold`
  converted to the type of `x` before the subtraction.
"""
function relu(x::T; max_value=Inf, negative_slope=0, threshold=0) where T
    if x >= max_value
        # Saturated region: clamp to the cap, expressed in x's own type.
        return oftype(x, max_value)
    elseif x >= threshold
        # Pass-through region.
        return x
    elseif negative_slope == 0
        # Below threshold with no leak.
        return zero(T)
    else
        # Leaky region: linear in the distance below the threshold.
        return negative_slope * (x - oftype(x, threshold))
    end
end

"""
    reluback(x, dy; max_value=Inf, negative_slope=0, threshold=0)

Scale the incoming gradient `dy` by a piecewise-constant factor selected by `x`.

The factor is `zero(T)` when `x >= max_value`, `one(T)` when `x >= threshold`,
and `negative_slope` converted to the type of `x` otherwise. `x` and `dy` must
have the same type `T`.
"""
function reluback(x::T, dy::T; max_value=Inf, negative_slope=0, threshold=0) where T
    local_slope = if x >= max_value
        zero(T)                      # saturated region
    elseif x >= threshold
        one(T)                       # pass-through region
    else
        oftype(x, negative_slope)    # leaky region
    end
    return dy * local_slope
end

# Declare `relu` via AutoGrad's `@primitive`; the trailing expression broadcasts
# `reluback` over `x` and `dy`, forwarding the same keyword arguments `o...`.
# NOTE(review): presumably this expression is the gradient w.r.t. `x` — confirm
# against the `@primitive` documentation.
@primitive relu(x; o...),dy,y  reluback.(x,dy; o...)
# Declare `reluback` via `@primitive` as well; the trailing expression is `zero(dy)`.
# NOTE(review): only one trailing expression is given although `reluback` takes
# two positional arguments — verify which argument it applies to.
@primitive reluback(x,dy; o...),ddx,dx  zero(dy)

Broadcasting with a keyword argument, e.g. `relu.(a; threshold=1)`, then gives the error: `MethodError: no method matching broadcasted(::typeof(relu), ::Array{Float64, 4}; threshold=1)`

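This is because a broadcast call with keyword arguments is lowered differently from one without: it goes through `Base.broadcasted_kwsyntax` rather than calling `Base.broadcasted` directly: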

julia> Meta.@lower relu.(a; threshold=1)                                                                                                                              
:($(Expr(:thunk, CodeInfo(                                                                                                                                            
    @ none within `top-level scope'                                                                                                                                   
1 ─ %1 = Base.broadcasted_kwsyntax                                                                                                                                    
│   %2 = Core.tuple(:threshold)                                                                                                                                       
│   %3 = Core.apply_type(Core.NamedTuple, %2)                                                                                                                         
│   %4 = Core.tuple(1)                                                                                                                                                
│   %5 = (%3)(%4)                                                                                                                                                     
│   %6 = Core.kwfunc(%1)                                                                                                                                              
│   %7 = (%6)(%5, %1, relu, a)                                                                                                                                        
│   %8 = Base.materialize(%7)                                                                                                                                         
└──      return %8                                                                                                                                                    
))))