FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

BitVector failure #1523

Open mhauru opened 1 week ago

mhauru commented 1 week ago
```julia
using Zygote: Zygote

struct VNV{TVal}
    vals::TVal
    bv::BitVector
end

# Constructing a BitVector inside the differentiated function triggers the error
f(x) = VNV(x, BitVector(undef, 1)).vals
Zygote.pullback(f, [1.0])
```

The above fails with

```text
ERROR: LoadError: ArgumentError: new: too few arguments (expected 3)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [4] BitArray
    @ ./bitarray.jl:39 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{BitVector}, ::UndefInitializer, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [6] f
    @ ~/projects/DynamicPPL.jl/tmp_zygote_bug.jl:10 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::typeof(Main.TmpZygoteBug.f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
  [9] pullback(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88
 [10] top-level scope
    @ ~/projects/DynamicPPL.jl/tmp_zygote_bug.jl:11
 [11] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [12] top-level scope
    @ REPL[3]:1
in expression starting at /Users/mhauru/projects/DynamicPPL.jl/tmp_zygote_bug.jl:1
```

This is on Zygote v0.6.70.

Using a Vector{Bool} instead of a BitVector works.
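A minimal sketch of the Vector{Bool} workaround, assuming the field type can simply be changed in the calling code. VNVBool and g are hypothetical names introduced only for this illustration, and the sketch has not been checked against v0.6.70 specifically.

```julia
using Zygote: Zygote

# Hypothetical variant of VNV whose mask field is a Vector{Bool} instead of a BitVector
struct VNVBool{TVal}
    vals::TVal
    bv::Vector{Bool}
end

# Same shape as f in the report, but allocating a Vector{Bool} for the mask
g(x) = VNVBool(x, Vector{Bool}(undef, 1)).vals

y, back = Zygote.pullback(g, [1.0])
grads = back([2.0])  # g is the identity on x, so the cotangent passes straight through
```

The trade-off is memory: a Vector{Bool} uses one byte per entry, whereas a BitVector packs entries into bits.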

willtebbutt commented 1 week ago

Looks like the constructor for BitVector is fairly involved. You could just use ChainRules to mark it as @non_differentiable, e.g.

```julia
using ChainRulesCore  # provides the @non_differentiable macro
@non_differentiable BitVector(a, b)
```

Running that in a fresh session seems to work for me locally. Since a BitVector only stores Bool values, marking its constructor as non-differentiable shouldn't drop any gradient information, so it should be safe.
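Combined with the original reproducer, the suggestion might look like the sketch below. It assumes ChainRulesCore is installed alongside Zygote and that Zygote picks up the rrule generated by @non_differentiable for the BitVector constructor, as the comment reports; treat it as an untested sketch rather than a confirmed fix.

```julia
using Zygote: Zygote
using ChainRulesCore  # provides @non_differentiable

struct VNV{TVal}
    vals::TVal
    bv::BitVector
end

# Define an rrule for the two-argument BitVector constructor that reports no tangents,
# so Zygote uses it instead of tracing into the BitArray inner constructor
@non_differentiable BitVector(a, b)

f(x) = VNV(x, BitVector(undef, 1)).vals

y, back = Zygote.pullback(f, [1.0])
grads = back([2.0])  # gradient flows through the vals field only
```

The rule covers only the two-argument call, such as BitVector(undef, n). Other BitVector constructor arities would need their own rule.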

mcabbott commented 1 week ago

If that works, you can make a 1-line PR to this file https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl to fix it permanently.