FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

batched matrix multiplication #541

Open vmoens opened 4 years ago

vmoens commented 4 years ago

It would be useful to have an adjoint for batched matrix multiplication (`batched_mul`), which currently fails with a `foreigncall` error. For example:

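```julia
using Zygote: @adjoint
using NNlib: batched_mul, batched_transpose
# Per slice k, C[:,:,k] = A[:,:,k] * B[:,:,k], so dA = Δ * Bᵀ and dB = Aᵀ * Δ
@adjoint function batched_mul(A, B)
    batched_mul(A, B), Δ -> (batched_mul(Δ, batched_transpose(B)), batched_mul(batched_transpose(A), Δ))
end
```
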
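The proposed rule is the ordinary matrix-product pullback applied to each slice independently. A short derivation, assuming real-valued arrays and writing $C_k = A_k B_k$ for batch index $k$ and $\Delta_k = \partial L / \partial C_k$ for the incoming gradient:

$$
\begin{aligned}
dL &= \sum_k \operatorname{tr}\!\left(\Delta_k^\top\, dC_k\right)
    = \sum_k \left[\operatorname{tr}\!\left(\Delta_k^\top\, dA_k\, B_k\right) + \operatorname{tr}\!\left(\Delta_k^\top A_k\, dB_k\right)\right] \\
   &= \sum_k \left[\operatorname{tr}\!\left((\Delta_k B_k^\top)^\top dA_k\right) + \operatorname{tr}\!\left((A_k^\top \Delta_k)^\top dB_k\right)\right]
\quad\Longrightarrow\quad
\frac{\partial L}{\partial A_k} = \Delta_k B_k^\top,\qquad
\frac{\partial L}{\partial B_k} = A_k^\top \Delta_k .
\end{aligned}
$$

Batched over $k$, these are `batched_mul(Δ, batched_transpose(B))` and `batched_mul(batched_transpose(A), Δ)`, matching the two entries returned by the pullback.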
mcabbott commented 4 years ago

This should already work as of Zygote v0.4.9; the adjoint was added in #531.
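One way to check that the rule is active is to compare Zygote's gradient of `batched_mul` against the slice-by-slice formulas from the adjoint proposed in the issue. This is a sketch assuming Zygote v0.4.9 or later and NNlib are installed; the array sizes, the weight array `W`, and the helper `loss` are illustrative choices, not part of either package.

```julia
using Zygote, NNlib, Random

Random.seed!(0)
A = randn(3, 4, 5)   # batch of five 3×4 matrices
B = randn(4, 2, 5)   # batch of five 4×2 matrices
W = randn(3, 2, 5)   # fixed weights, so the upstream gradient Δ equals W

# Hypothetical scalar loss; its gradient with respect to batched_mul(A, B) is W
loss(A, B) = sum(W .* batched_mul(A, B))

gA, gB = Zygote.gradient(loss, A, B)

# Reference gradients computed one slice at a time: dA = Δ * Bᵀ, dB = Aᵀ * Δ
refA = similar(A)
refB = similar(B)
for k in axes(A, 3)
    refA[:, :, k] = W[:, :, k] * B[:, :, k]'
    refB[:, :, k] = A[:, :, k]' * W[:, :, k]
end

println(gA ≈ refA && gB ≈ refB)
```

The weighted loss is used instead of a plain `sum` so that the upstream gradient differs across entries, which makes a transposition mistake in the rule show up as a mismatch.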

Roger-luo commented 4 years ago

By the way, I'm working on a new BatchedArray implementation, which also defines these rules, so you could use that package instead: https://github.com/Roger-luo/BatchedArrays.jl. It's not as mature as NNlib, but it could be faster.