mhauru opened 1 week ago
Looks like the constructor for BitVector is fairly involved. You could just use ChainRules to mark it as @non_differentiable, e.g.
@non_differentiable BitVector(a, b)
in a fresh session seems to work okay for me locally. A BitVector only ever holds Bools, which carry no gradient, so I don't think you can lose any gradient information by doing this. It should be safe.
If that works, you can make a 1-line PR to this file https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl to fix it permanently.
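For reference, the @non_differentiable BitVector(a, b) line above does roughly the same thing as the hand-written rule below. The macro also defines an frule and handles keyword arguments, so this is only a simplified sketch of its rrule half. It assumes the two arguments are the usual (undef, n) pair. Like the macro line, defining this outside ChainRules is type piracy, so it belongs in the upstream nondiff.jl rather than in user code. Use it instead of the macro, not in addition to it.

```julia
using ChainRulesCore

# Simplified hand-written equivalent of `@non_differentiable BitVector(a, b)`:
# run the constructor as usual, and declare that no gradient flows back
# to the function (the BitVector type itself) or to either argument.
function ChainRulesCore.rrule(::Type{BitVector}, a, b)
    y = BitVector(a, b)
    bitvector_pullback(_) = (NoTangent(), NoTangent(), NoTangent())
    return y, bitvector_pullback
end

# Usage: build a length-3 BitVector through the rule and pull back a dummy cotangent.
y, pb = ChainRulesCore.rrule(BitVector, undef, 3)
```

Every entry returned by the pullback is NoTangent(), which is exactly the "no gradient info to drop" argument: the output is a container of Bools, so there is nothing to propagate.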
The above fails with
on v0.6.70.
Switching to e.g. Vector{Bool} rather than a BitVector works.
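A minimal sketch of that workaround in plain Base Julia. It assumes (hypothetically) that the BitVector comes from a broadcast comparison such as x .> 0, which is a common way BitVectors appear without being constructed explicitly. The names x, mask_bits, mask_vec, mask_map and positive_sum are illustrative only. This snippet shows how to get a Vector{Bool} instead; it does not run any AD itself.

```julia
# Hypothetical input values, only for illustration.
x = [-1.5, 0.3, 2.0, -0.2]

# Broadcasting a comparison allocates a BitVector.
mask_bits = x .> 0

# Two ways to end up with a plain Vector{Bool} instead:
mask_vec = Vector{Bool}(mask_bits)  # convert an existing BitVector
mask_map = map(>(0), x)             # map (unlike broadcast) returns a Vector{Bool}

# Either Bool vector works for logical indexing just like the BitVector did.
positive_sum(v, m) = sum(v[m])
total = positive_sum(x, mask_vec)
```

Using map avoids allocating the BitVector in the first place, while the Vector{Bool}(...) conversion is handy when the mask comes from code you don't control.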