JuliaArrays / LazyArrays.jl

Lazy arrays and linear algebra in Julia
MIT License

Zygote compat is lacking #232

Open torfjelde opened 1 year ago

torfjelde commented 1 year ago

Zygote doesn't interact too nicely with LazyArrays.jl it seems, e.g.:

julia> f(x) = sum(BroadcastArray(exp, x))
f (generic function with 1 method)

julia> Zygote.gradient(f, randn(10))
ERROR: type Array has no field f
Stacktrace:
  [1] adjoint
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:229 [inlined]
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [3] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:50 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(LazyArrays.call), ::ArrayLayouts.DenseColumnMajor, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:52 [inlined]
  [6] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:82 [inlined]
  [7] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:57 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{BroadcastArray}, ::typeof(exp), ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[48]:1 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [12] pullback
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [14] top-level scope
    @ REPL[50]:1

julia> g(x) = sum(LazyArray(@~ exp.(x)))
g (generic function with 1 method)

julia> Zygote.gradient(g, randn(10))
ERROR: MethodError: no method matching LazyArray(::Vector{Float64})
Closest candidates are:
  LazyArray(::Base.Broadcast.Broadcasted) at ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:35
  LazyArray(::Applied) at ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:193
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(ctx::Zygote.Context{false}, f::Type{LazyArray}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:9
 [3] _pullback
   @ ./REPL[53]:1 [inlined]
 [4] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [5] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [6] pullback
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [8] top-level scope
   @ REPL[54]:1

The first error can be "fixed" (I'm not entirely certain if this is the right way to go about it) by defining a chain rule:

julia> using ChainRulesCore

julia> function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
           return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(f, randn(10))
([0.24117702568683322, 2.478340448616497, 2.433266795642693, 1.6163793920298133, 1.8859252985478665, 3.9539878829654223, 1.2578105524502685, 0.48545348574922, 0.8710494256114425, 3.0853524634917076],)
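
A quick way to sanity-check a result like the one above: since f(x) = sum(exp.(x)), the gradient should equal exp.(x) elementwise. The sketch below repeats the rule from this comment and compares against that closed form. It assumes LazyArrays, Zygote and ChainRulesCore are installed, and that the rule is defined before the first gradient call; it is a check under those assumptions, not a statement about how LazyArrays should implement this.

```julia
using LazyArrays, Zygote, ChainRulesCore

# Same rule as proposed above: delegate the BroadcastArray constructor
# to reverse-mode AD of Broadcast.broadcasted.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
    return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end

f(x) = sum(BroadcastArray(exp, x))

x = randn(10)
grad = Zygote.gradient(f, x)[1]

# d/dx_i sum(exp.(x)) = exp(x_i), so this comparison should hold
# if the rule is picked up by Zygote.
isapprox(grad, exp.(x))
```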

Maybe the rest can be addressed this way too.

Are rules from CRC something that would be welcomed?

dlfivefifty commented 1 year ago

Hmm.... that's a good question.... I'm usually hesitant to add "*Core.jl" dependencies because a lot of them are of questionable usage but ChainRulesCore.jl might be an exception.

One alternative solution is to make a glue package a la FastTransformsForwardDiff. (I'm wondering whether that should have been FastTransformsChainRulesCore.jl...)

torfjelde commented 1 year ago

Either alternative is okay with me:)

torfjelde commented 1 year ago

You just say which alternative you prefer, and I can try to contribute towards it.

dlfivefifty commented 1 year ago

Let's put it in a separate package for now so we can work out the kinks. We can always merge it back here (in the event there's a good reason to have it).

devmotion commented 1 year ago

It seems this is a good use case for weak deps. Some packages have already started moving their ChainRules definitions to weak deps. The definitions would be loaded only on Julia >= 1.9 (if you don't want to use Requires on older Julia versions), but I think it would be the better long-term solution.

torfjelde commented 1 year ago

It would suck if we'd have to wait until Julia 1.9 before we could make use of this though :confused:

devmotion commented 1 year ago

I assume it already works with the beta version, so I think you can already use it without compiling julia.

dlfivefifty commented 1 year ago

Can we do a separate package that works now, and becomes a weak dependency in Julia v1.9?

devmotion commented 1 year ago

If a weak dependency is loaded, an extension (usually a single file) in the ext subfolder is loaded (and precompiled, in contrast to the Requires hacks!). AFAIK there are no separate packages involved or loaded in the extension apart from the weak dependency and the package + hard dependencies, and making the glue package a hard dependency would defeat its purpose. An example is shown in this PR: https://github.com/JuliaMath/ChangesOfVariables.jl/pull/12
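
For this issue, such an extension might look roughly like the sketch below. It assumes the rule proposed earlier in this thread is what gets moved into the extension; the module and file name `LazyArraysChainRulesCoreExt` are hypothetical, and the Project.toml entries are shown as comments with the UUID left as a placeholder.

```julia
# Project.toml of LazyArrays.jl would gain something like
# (UUID placeholder intentionally left unfilled):
#
#   [weakdeps]
#   ChainRulesCore = "<ChainRulesCore UUID>"
#
#   [extensions]
#   LazyArraysChainRulesCoreExt = "ChainRulesCore"
#
# ext/LazyArraysChainRulesCoreExt.jl (hypothetical file name):
module LazyArraysChainRulesCoreExt

using LazyArrays
using ChainRulesCore

# Rule from earlier in this thread: differentiate the BroadcastArray
# constructor by handing the call to Broadcast.broadcasted and letting
# the AD backend (e.g. Zygote) provide the pullback.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode},
    ::Type{LazyArrays.BroadcastArray},
    f,
    args...,
)
    return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
end

end # module
```

On Julia >= 1.9 the module would be loaded automatically once both LazyArrays and ChainRulesCore are loaded in the same session; on older versions it is simply ignored unless a Requires-based fallback is added.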

dlfivefifty commented 1 year ago

I see. I think a weak dependency here would be fine. I would suggest forgetting the separate project and just requiring v1.9.

oschulz commented 1 year ago

We use weak deps for ChangesOfVariables.jl now, and it works like a charm on Julia v1.9:

julia> @time_imports import ChangesOfVariables
      0.6 ms  ChangesOfVariables

julia> @time_imports import ChainRulesCore
      0.1 ms  Compat
     58.9 ms  ChainRulesCore
      0.4 ms  ChangesOfVariables → ChainRulesCoreExt