under-Peter / OMEinsum.jl

One More Einsum for Julia! With runtime order-specification and high-level adjoints for AD
https://under-peter.github.io/OMEinsum.jl/dev/
MIT License

use @adjoint! instead of @adjoint for einsum! AD #18

Closed: GiggleLiu closed this issue 5 years ago

GiggleLiu commented 5 years ago

According to the discussion in https://github.com/under-Peter/OMEinsum.jl/pull/17#discussion_r290195796

The Zygote source code contains some examples of defining adjoints with @adjoint!. Because the way gradients are accumulated into the gradient tensor of a mutable structure is a bit tricky, we will revisit this issue later (after dispatch).

This is the Zygote paper, https://arxiv.org/abs/1810.07951

under-Peter commented 5 years ago

In that comment, Mike says not to define adjoints for mutating functions, so I think we should use the regular @adjoint and define it only for einsum (without the !).

At least for the moment; I can imagine progress on Zygote's side in the coming weeks.

GiggleLiu commented 5 years ago

OK, I see. If I understand correctly, he means that neither @adjoint nor @adjoint! is currently safe for in-place functions. That would mean giving up on differentiating broadcasting operations that have to be in-place.
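One common reason mutation and reverse-mode AD do not mix can be sketched without any AD package: a pullback closure often captures its inputs by reference, so an in-place update made after the forward pass silently changes the gradient. The function `square_with_pullback` below is a hypothetical hand-written rule for illustration only, not Zygote's internals.

```julia
# Hypothetical hand-written rule for y = x .^ 2.
# The returned pullback closure captures the array x by reference, not a copy.
square_with_pullback(x) = (x .^ 2, dy -> 2 .* x .* dy)

x = [1.0, 2.0]
y, back = square_with_pullback(x)
g_before = back([1.0, 1.0])   # [2.0, 4.0], the correct gradient
x .= 0.0                      # an in-place update after the forward pass
g_after = back([1.0, 1.0])    # [0.0, 0.0]: the pullback now reads the mutated x
```

Defining the rule on a non-mutating function, which allocates its own output, avoids this class of bug.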

I would suggest the following design to support fully featured AD:

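einsum(ixs, xs, iy, y_shape) = einsum!...   # define the adjoint on this method: it supports every feature, yet is not in-place from the caller's point of view

einsum(ixs, xs, iy) = einsum(ixs, xs, iy, generate_output_shape(ixs, xs, iy))   # the output shape depends on the input sizes, so xs is needed too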

Cool?
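The proposed layering can be sketched in plain Julia under a few assumptions: labels are `Char`s, inputs are real arrays, and the names `naive_einsum!`, `einsum`, `output_shape` and `einsum_pullback` are hypothetical stand-ins, not OMEinsum.jl's API. A mutating kernel fills a preallocated output; a non-mutating 4-argument `einsum` allocates that output from an explicit shape, and this is where an AD rule would attach. The sketch also suggests one reason the explicit shape helps: the gradient of each operand is itself an einsum whose output may carry labels that no input of the backward einsum has, so its size cannot be inferred.

```julia
# Sketch only: all function names here are hypothetical, not OMEinsum.jl's API.
# Labels are Chars; inputs are assumed to be real-valued arrays.

# Record the size of every label seen on the given tensors.
function label_sizes!(sizes, ixs, xs)
    for (ix, x) in zip(ixs, xs), (l, n) in zip(ix, size(x))
        sizes[l] = n
    end
    return sizes
end

# Mutating kernel: overwrite y with the contraction of xs (naive loop over all labels).
function naive_einsum!(ixs, xs, iy, y)
    labels = Char[]
    for ix in (ixs..., iy), l in ix
        l in labels || push!(labels, l)
    end
    sizes = label_sizes!(Dict{Char,Int}(), (ixs..., iy), (xs..., y))
    fill!(y, zero(eltype(y)))
    for ci in CartesianIndices(Tuple(sizes[l] for l in labels))
        val = Dict(zip(labels, Tuple(ci)))   # label => current index value
        term = one(eltype(y))
        for (ix, x) in zip(ixs, xs)
            term *= x[(val[l] for l in ix)...]
        end
        y[(val[l] for l in iy)...] += term
    end
    return y
end

# Non-mutating, fully specified entry point: the AD rule would be defined here.
function einsum(ixs, xs, iy, y_shape)
    T = promote_type(map(eltype, xs)...)
    return naive_einsum!(ixs, xs, iy, zeros(T, y_shape))
end

# Convenience method: infer the output shape from the inputs.
# This fails (KeyError) when a label appears only in iy.
function output_shape(ixs, xs, iy)
    s = label_sizes!(Dict{Char,Int}(), ixs, xs)
    return Tuple(s[l] for l in iy)
end
einsum(ixs, xs, iy) = einsum(ixs, xs, iy, output_shape(ixs, xs, iy))

# Gradient w.r.t. xs[k]: contract dy with every other operand back onto ixs[k].
# Passing size(xs[k]) explicitly covers labels of ixs[k] that occur nowhere else.
function einsum_pullback(ixs, xs, iy, dy)
    return ntuple(length(xs)) do k
        others = [j for j in eachindex(xs) if j != k]
        einsum((iy, Tuple(ixs[j] for j in others)...),
               (dy, Tuple(xs[j] for j in others)...),
               ixs[k], size(xs[k]))
    end
end

# With Zygote, the rule could be registered on the 4-argument method only,
# e.g. (not run here, shown as an assumption):
# Zygote.@adjoint einsum(ixs, xs, iy, y_shape) =
#     einsum(ixs, xs, iy, y_shape), dy -> (nothing, einsum_pullback(ixs, xs, iy, dy), nothing, nothing)
```

For example, the row sum `einsum((('i','j'),), (A,), ('i',))` has a pullback whose output labels `('i','j')` include `'j'`, which the backward einsum's only input `dy` does not carry; the 3-argument method could not infer its size, but the 4-argument method receives it as `size(A)`.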

under-Peter commented 5 years ago

Works for me.