SBuercklin / UnitfulChainRules.jl

ChainRules.jl integration for Unitful.jl
MIT License

Fractional powers not supported #10

Closed. bks-nist closed this issue 2 years ago.

bks-nist commented 2 years ago

Sometimes one ends up raising Quantities to fractional powers, but differentiating through such expressions doesn't currently work with UnitfulChainRules.jl (here with Zygote.jl and Unitful.jl loaded):

```
julia> Zygote.pullback(x -> x^(1//3), 3.0u"W^3")[2](1)
ERROR: MethodError: no method matching log(::Quantity{ComplexF64, 𝐋^6 𝐌^3 𝐓^-9, Unitful.FreeUnits{(W^3,), 𝐋^6 𝐌^3 𝐓^-9, nothing}})
Closest candidates are:
  log(::T, ::T) where T<:Number at C:\Program Files\Julia-1.7.0\share\julia\base\math.jl:315
  log(::Number, ::Number) at C:\Program Files\Julia-1.7.0\share\julia\base\math.jl:358
  log(::RoundingMode, ::ForwardDiff.Dual{Ty}) where Ty at C:\Users\bks1\.julia\packages\ForwardDiff\wAaVJ\src\dual.jl:145
  ...
Stacktrace:
  [1] _pow_grad_p(x::Quantity{Float64, 𝐋^6 𝐌^3 𝐓^-9, Unitful.FreeUnits{(W^3,), 𝐋^6 𝐌^3 𝐓^-9, nothing}}, p::Rational{Int64}, y::Quantity{Float64, 𝐋^2 𝐌 𝐓^-3, Unitful.FreeUnits{(W,), 𝐋^2 𝐌 𝐓^-3, nothing}})
    @ ChainRules C:\Users\bks1\.julia\packages\ChainRules\EyLkg\src\rulesets\Base\fastmath_able.jl:320
  [2] (::ChainRules.var"#1244#1277"{Int64, Quantity{Float64, 𝐋^6 𝐌^3 𝐓^-9, Unitful.FreeUnits{(W^3,), 𝐋^6 𝐌^3 𝐓^-9, nothing}}, Rational{Int64}, ProjectTo{Real, NamedTuple{(), Tuple{}}}, Quantity{Float64, 𝐋^2 𝐌 𝐓^-3, Unitful.FreeUnits{(W,), 𝐋^2 𝐌 𝐓^-3, nothing}}})()
    @ ChainRules C:\Users\bks1\.julia\packages\ChainRules\EyLkg\src\rulesets\Base\fastmath_able.jl:196
  [3] unthunk
    @ C:\Users\bks1\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\thunks.jl:199 [inlined]
  [4] wrap_chainrules_output
    @ C:\Users\bks1\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:104 [inlined]
  [5] map
    @ .\tuple.jl:223 [inlined]
  [6] wrap_chainrules_output
    @ C:\Users\bks1\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:105 [inlined]
  [7] ZBack
    @ C:\Users\bks1\.julia\packages\Zygote\IoW2g\src\compiler\chainrules.jl:205 [inlined]
  [8] Pullback
    @ .\REPL[51]:1 [inlined]
  [9] (::typeof(∂(#111)))(Δ::Int64)
    @ Zygote C:\Users\bks1\.julia\packages\Zygote\IoW2g\src\compiler\interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(#111))})(Δ::Int64)
    @ Zygote C:\Users\bks1\.julia\packages\Zygote\IoW2g\src\compiler\interface.jl:41
 [11] top-level scope
    @ REPL[51]:1
```

SBuercklin commented 2 years ago

Adding an extra method for _pow_grad_p would solve this, but then this package would have to depend on ChainRules.jl.

Instead, a custom rrule for ^ is a better solution. We can probably just copy it over from here and write our own _pow_grad_x and _pow_grad_p methods that do the math in a way compatible with Unitful.jl.
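A minimal sketch of that idea, assuming ChainRulesCore.jl and Unitful.jl are loaded. This is a hypothetical illustration, not the rule that was actually merged into UnitfulChainRules.jl: the pullback name unitful_pow_pullback is made up, the exponent is restricted to Integer or Rational so that Unitful can form the result unit exactly, and the exponent is assumed to be non-differentiable. That last assumption sidesteps the failing call in the stack trace: d(x^p)/dp = x^p * log(x) needs the logarithm of a dimensionful quantity, which has no meaningful unit.

```julia
using ChainRulesCore   # rrule, NoTangent
using Unitful          # Quantity, @u_str, unit, ustrip

# Hypothetical sketch: reverse rule for x^p where x carries units and p is exact.
function ChainRulesCore.rrule(::typeof(^), x::Unitful.Quantity, p::Union{Integer,Rational})
    y = x^p
    function unitful_pow_pullback(Δ)
        # Power rule: d(x^p)/dx = p * x^(p - 1), with units of unit(x)^(p - 1)
        ∂x = Δ * p * x^(p - 1)
        # Assumption: no tangent for the exponent, since x^p * log(x) would
        # require the log of a dimensionful quantity
        return NoTangent(), ∂x, NoTangent()
    end
    return y, unitful_pow_pullback
end
```

Called directly on the example from the issue, ChainRulesCore.rrule(^, 3.0u"W^3", 1//3) returns y = 3.0^(1/3) W, and its pullback applied to 1 gives ∂x = (1/3) * 3.0^(-2/3) W^-2, i.e. watts of output per cubic watt of input.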

SBuercklin commented 2 years ago

@bks-nist Should be fixed on main. I have a couple more rules I want to include, and then I'll tag a new release.

bks-nist commented 2 years ago

That's great, thank you!