MichielStock / Kronecker.jl

A general-purpose toolbox for efficient Kronecker-based algebra.
MIT License

Efficient Kronecker Gradients in Zygote #11

Open jessebett opened 5 years ago

jessebett commented 5 years ago

Can you review the implementation of the gradients for kron in Zygote and Tracker? I ported these directly from TensorFlow.

Specifically, could you comment on the implementation and on whether it could benefit from your package?

MichielStock commented 5 years ago

I will look at it Friday during the hackathon!

MichielStock commented 5 years ago

Can we discuss irl somewhere this week? Have some questions.

Btw, the kron from Zygote seems to run a bit faster than the one in Base.

jessebett commented 5 years ago

Yes. I'm around tomorrow and will also be at the hackathon on Friday working on Zygote stuff. The kron in Zygote is just some reshape rules. I did not benchmark against Base, but that's pretty surprising; I wonder why.

samanklesaria commented 4 years ago

The reshape/broadcasting implementation doesn't use any of the nice structure from the module. What about the following:

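using Zygote, Kronecker

vec(A) = reshape(A, :)

# Assumes the incoming cotangent dy is itself Kronecker-structured,
# with factors dy.A and dy.B.
Zygote.@adjoint kronecker(A, B) = kronecker(A, B), dy -> (
  dy.A .* (vec(dy.B)' * vec(B)), dy.B .* (vec(A)' * vec(dy.A)))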
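
A short derivation may help when checking the adjoint above. It is only a sketch under assumptions that are not stated in the thread: A is m x n, B is p x q, all entries are real, the cotangent of the product is written \bar{Y}, and \langle X, Y \rangle denotes vec(X)^T vec(Y). The first two lines give the gradient of a dense Kronecker product. The last line specializes it to a cotangent that is itself a Kronecker product, which appears to be the case the adjoint handles through dy.A and dy.B.

```latex
% Assumed notation (not from the thread): A is m x n, B is p x q, real entries,
% K = A \otimes B, \bar{Y} is the cotangent of K, <X, Y> = vec(X)^T vec(Y).
\begin{align*}
K_{(i-1)p+k,\,(j-1)q+l} &= A_{ij}\, B_{kl} \\
\bar{A}_{ij} &= \sum_{k,l} \bar{Y}_{(i-1)p+k,\,(j-1)q+l}\, B_{kl},
\qquad
\bar{B}_{kl} = \sum_{i,j} \bar{Y}_{(i-1)p+k,\,(j-1)q+l}\, A_{ij} \\
\text{if } \bar{Y} = \bar{Y}_A \otimes \bar{Y}_B: \quad
\bar{A} &= \bar{Y}_A \,\langle \bar{Y}_B, B \rangle,
\qquad
\bar{B} = \bar{Y}_B \,\langle \bar{Y}_A, A \rangle
\end{align*}
```

The last line matches dy.A .* (vec(dy.B)' * vec(B)) and dy.B .* (vec(A)' * vec(dy.A)) term by term. A general dense cotangent does not factor this way, so the rule covers only the Kronecker-structured case. For complex inputs the ' operator also conjugates, which this real-valued sketch does not address.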