FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Complex valued AD for `TensorOperations` #151

Open GiggleLiu opened 5 years ago

GiggleLiu commented 5 years ago

Zygote.jl now works pretty well with TensorOperations.jl: it gives correct results for real numbers without any effort. But for complex numbers, the gradient differs by a conjugate.

As an example

The untouched version

using TensorOperations
using Zygote: @adjoint, gradient

a = randn(ComplexF64, 3, 3)
g(a) = (x = conj.(a); real(@tensor a[i,j]*x[i,j]))
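
For reference, a minimal check of the behavior described below (using the definitions above):

gradient(g, a)[1] ≈ 2 .* conj.(a)   # the reported (conjugated) result; the expected gradient is 2a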

The correct version (I believe)

# Here we used the indexing magic of einsum:
# https://giggleliu.github.io/2019/04/02/einsumbp.html
@adjoint function tensorcontract(A, IA, B, IB, IC)
    y = tensorcontract(A, IA, B, IB, IC)
    y, _C -> (tensorcontract(_C, IC, conj(B), IB, IA), nothing,
              tensorcontract(conj(A), IA, _C, IC, IB), nothing, nothing)
end

f(a) = tensorcontract(a, (1,2), conj(a), (1,2), ())[] |> real

f is equivalent to g; the exact gradient should be f'(a) = 2a, but g'(a) gives 2conj(a). I'm wondering which function @tensor calls into that gives the incorrect gradient. @Jutho @under-Peter
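
A quick numerical sanity check of the proposed adjoint (a sketch, assuming the definitions above are loaded; the expected value is the one claimed above):

a = randn(ComplexF64, 3, 3)
gradient(f, a)[1] ≈ 2 .* a   # should hold if the custom adjoint above is correct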

Performance

After fixing the conjugated-array broadcasting issue (https://github.com/FluxML/Zygote.jl/pull/146#issuecomment-482612224), the backward pass is 5 times slower than the forward pass (a factor of two can be explained by differentiating with respect to the two inputs A and B):

julia> a = randn(ComplexF64, 1000, 1000)
julia> @benchmark f($a) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  15.26 MiB
  allocs estimate:  3
  --------------
  minimum time:     9.373 ms (0.00% GC)
  median time:      10.425 ms (4.90% GC)
  mean time:        11.365 ms (15.28% GC)
  maximum time:     93.405 ms (89.91% GC)
  --------------
  samples:          88
  evals/sample:     1

julia> @benchmark gradient(f, $a) seconds=1
BenchmarkTools.Trial: 
  memory estimate:  106.81 MiB
  allocs estimate:  17
  --------------
  minimum time:     46.584 ms (7.04% GC)
  median time:      50.088 ms (8.95% GC)
  mean time:        57.562 ms (18.77% GC)
  maximum time:     169.280 ms (73.77% GC)
  --------------
  samples:          18
  evals/sample:     1

BTW: I notice that the unneeded gradients in the output tuple are still calculated. Is it possible to avoid this kind of overhead, e.g. by knowing which gradients are actually needed? @MikeInnes
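
One possible direction (a sketch, not something Zygote does for you automatically) is to write the same contraction rule through the ChainRulesCore interface and wrap each input cotangent in @thunk, so that a cotangent is only materialized when the consumer actually asks for it; whether the extra work is really skipped depends on how the AD engine unthunks. Assuming the five-argument tensorcontract(A, IA, B, IB, IC) form used above:

using TensorOperations
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(tensorcontract), A, IA, B, IB, IC)
    y = tensorcontract(A, IA, B, IB, IC)
    function tensorcontract_pullback(ȳ)
        # Each gradient is wrapped in a thunk, so it is only computed on demand.
        Ā = @thunk tensorcontract(ȳ, IC, conj(B), IB, IA)
        B̄ = @thunk tensorcontract(conj(A), IA, ȳ, IC, IB)
        # No gradients for the function itself or for the index arguments.
        return NoTangent(), Ā, NoTangent(), B̄, NoTangent(), NoTangent()
    end
    return y, tensorcontract_pullback
end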

guochu commented 4 years ago

This kind of behavior for the complex gradient should be expected: for holomorphic functions such as matrix multiplication, the result needs to be conjugated. We have just written a paper (https://arxiv.org/pdf/2003.04295.pdf) which proposes an automatic differentiation algorithm for general complex loss functions that derives the correct gradient (the same result as if the complex numbers were treated as tuples of two real numbers). For a general function, the adjoint function should be defined as in Eq. 15; in the special case of a holomorphic function, the result should be conjugated. However, I would like to point out that currently not all functions are defined according to the rule in the paper. For example, the adjoint of dot, currently of the form `v -> (v x, v y)`, should be modified to `v -> (v' x, v y)`. The usage of the dot function in the complex case will result in errors.
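
To illustrate the point about dot under the convention used in this thread (the gradient of a real loss with respect to z is 2\partial L/\partial z*), here is a hedged sketch using a hypothetical helper mydot, rather than touching the rule Zygote actually ships for dot:

using Zygote: @adjoint, gradient

# Hypothetical helper with the same meaning as dot(x, y) = sum(conj(x) .* y).
mydot(x, y) = sum(conj(x) .* y)

# mydot is anti-holomorphic in x and holomorphic in y, so under the convention above:
#   x̄ = conj(Δ) .* y
#   ȳ = Δ .* x
@adjoint mydot(x, y) = mydot(x, y), Δ -> (conj(Δ) .* y, Δ .* x)

# Quick check for a real loss: the gradient with respect to x should be y.
x, y = randn(ComplexF64, 3), randn(ComplexF64, 3)
gradient(x -> real(mydot(x, y)), x)[1] ≈ y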

GiggleLiu commented 4 years ago

@guochu Maybe you wanted to comment under this issue: https://github.com/FluxML/Zygote.jl/issues/29. We had a lot of discussion in that issue, and the theory is now very clear. The AD support for complex numbers in Zygote may not be complete yet, but the remaining problems are technical.

What you are describing is exactly Wirtinger calculus. Your formula is exactly the same as the one in Akira's book “Complex Valued Neural Networks”, and many of the rules in your work are already covered in my blog. Although it is not something new, it is very impressive that you figured it out on your own!

guochu commented 4 years ago

Thanks for your reply! I read your blog and yes, I think the correct gradient as well as the complex chain rule should be very well known. It is just that, to define a reasonable adjoint function once you know the chain rule, you need to go one small step further (from Eq. 10 to Eq. 15 in the referenced paper). At least this step was non-trivial for me. I knew the required gradient is 2\partial f/\partial z* and also knew the complex chain rule, but when I was using Zygote with complex functions my optimization simply would not converge. I then tried to work out how to define a correct adjoint function to make it work, which led to the manuscript mentioned above. It can serve as a reference for myself and hopefully clarify some doubts for newcomers.
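
For readers following along, the Wirtinger derivatives behind the 2\partial f/\partial z* expression above are the standard ones (with z = x + iy):

\partial f/\partial z  = (\partial f/\partial x - i \partial f/\partial y)/2
\partial f/\partial z* = (\partial f/\partial x + i \partial f/\partial y)/2

so for a real-valued loss L the gradient discussed in this thread is \bar z = 2\partial L/\partial z* = \partial L/\partial x + i \partial L/\partial y.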

Maybe these rules are already the guideline Zygote follows for complex functions, and there are just bugs in some functions. I will report a bug then.

Thanks again!

GiggleLiu commented 4 years ago

I knew the required gradient is 2\partial f/\partial z* and also knew the complex chain rule, but when I was using Zygote with complex functions my optimization simply would not converge.

Better to submit an issue. If I understand correctly, Eq. 10-15 are about obtaining the gradient for a complex-valued loss?

guochu commented 4 years ago

Sure, I have already submitted an issue. If Zygote is designed to follow this rule, then the definition of the adjoint function "dot" (which is a really frequently used function, I think) is wrong. That, I figured out, is the reason my program did not converge, and I had to redefine it myself. I am sure there are other problems with other non-holomorphic functions.

The goal of the manuscript is not to explain how to compute the complex gradient using the chain rule, which should be well known and is Eq. 10 in the manuscript. In reverse-mode automatic differentiation the computer does not directly evaluate the chain rule; it uses user-defined "adjoint functions" (which you must be very familiar with), so that it can automatically derive the gradient of the composite function, arriving at the same result as the chain rule. The manuscript explains how to define such an adjoint function, which is Eq. 15 in the general case, rather than giving the explicit form of the complex gradient. In principle Eq. 15 can be derived "straightforwardly" from the chain rule in Eq. 10, but I have not seen it written down formally anywhere else. If you know of such references, I would be very grateful if you pointed them out to me.
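
For concreteness, the kind of composition rule being described can be written in the notation above (my paraphrase under the convention \bar z = 2\partial L/\partial z* for a real loss L, not a quotation of the manuscript's Eq. 15): for an intermediate result y = f(x),

\bar x_j = \sum_i [ \bar y_i conj(\partial y_i/\partial x_j) + conj(\bar y_i) \partial y_i/\partial x_j* ]

which reduces to \bar x_j = \sum_i \bar y_i conj(\partial y_i/\partial x_j) when f is holomorphic; this is the extra conjugation mentioned at the top of the thread.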

GiggleLiu commented 4 years ago

Sure, I have already submitted an issue.

where?

the definition of the adjoint function "dot" (which is a really frequently used function, I think) is wrong.

Could you explain?

I am interested to know how you would comment on the backward rules defined in OMEinsum.jl? They are similar to dot's. We use them to write our TRG and CTMRG algorithms, and the training works very well, thankfully.

If I understand correctly, Eq. 10-15 are about obtaining the gradient for a complex-valued loss?

I suppose you mean yes? You should definitely read and cite Akira's book. It can easily be found by googling, and it covers every detail in your paper, including Eqs. 10-15. The version in my blog is just a simplified version for a real loss function.

If you know of such references, I would be very grateful if you pointed them out to me.

But I am not sure you can find a digital version of the book; I read it at the library.

guochu commented 4 years ago

where?

Here it is. #540

Could you explain?

The correct version can be derived from Eq. 15, and it is listed in the table. The table also contains the adjoint function for matrix multiplication, which is similar to tensor contraction; as you can see, a conjugate has to be taken. I looked at this page. In the Autodiff section it mentions the implementation of the gradient; I am sure that form would be problematic for complex tensors, but I am not sure whether it does the conjugation internally for complex tensors or not. If the package does not do the conjugation and your code still works well with it, I am afraid that is either because your system has time-reversal symmetry or because you are really lucky. If you are still not sure how to derive the adjoint function of "dot" based on Eq. 15, maybe we could talk privately.

Thanks for pointing out the book, but I do not have access to it... I can only find this thesis, and I do not know whether it relates to Sec. 4 of the book you mentioned. In the thesis I only found the complex chain rule and did not find an explicit rule for the adjoint function of a complex function.

cossio commented 3 years ago

Bump. Any updates on this?

guochu commented 3 years ago

Bump. Any updates on this?

It seems that you can no longer get the gradient of TensorOperations functions simply from the above code. Anyway, you can still manually write the adjoint for tensor contraction, which should be relatively straightforward.
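
For example, a minimal hand-written rule under the same convention, using a plain matrix product as a stand-in for a general contraction (a sketch with a hypothetical helper named contract, not part of TensorOperations):

using LinearAlgebra
using Zygote: @adjoint, gradient

# Contract the second index of A with the first index of B, i.e. an ordinary matrix product.
contract(A, B) = A * B

# C = A*B is holomorphic in both arguments, so each cotangent uses the
# conjugate transpose (') of the other factor:
@adjoint contract(A, B) = contract(A, B), Δ -> (Δ * B', A' * Δ)

# Quick check: the loss below is sum(abs2, A), whose gradient should be 2A.
A = randn(ComplexF64, 4, 4)
gradient(A -> real(tr(contract(A, A'))), A)[1] ≈ 2 .* A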

Jutho commented 3 years ago

I still have not found time for this myself, but there is the integration of TensorOperations.jl in Tullio.jl, which can also compute the gradients, and there are also https://github.com/mcabbott/TensorGrad.jl and https://github.com/ho-oto/TensorRules.jl, which might still be functional.