FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Mutable struct bugs #1111

Open mcabbott opened 2 years ago

mcabbott commented 2 years ago

Something is currently broken about the handling of mutable structs. One place where I can isolate this is the following example, where Ref(Ref(x)) silently leads to wrong answers:

using Zygote

gradient(x -> abs2(x.x) + 7 * real(x.x), Ref(1+im))  # works correctly, ((x = 9.0 + 2.0im,),)

gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im)))  # gives (nothing,)

That has been broken since at least 0.6.0. The following variant, which wraps the Ref in an array ([Ref(x)] rather than Ref(Ref(x))), was not broken in 0.6.20 but currently gives a wrong answer:

gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)])  # wrong answer, ([(x = 16.0 + 2.0im,)],)
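For reference, the expected value can be checked by hand. This is a sketch assuming Zygote's convention that, for a real-valued f of a complex z = a + bi, the gradient is ∂f/∂a + i ∂f/∂b (so that, e.g., the gradient of abs2 at z is 2z):

$$
\begin{aligned}
f(z) &= |z|^2 + 7\,\mathrm{Re}\,z = a^2 + b^2 + 7a, \qquad z = a + b\,i \\
\frac{\partial f}{\partial a} + i\,\frac{\partial f}{\partial b} &= (2a + 7) + 2b\,i \\
&= 9 + 2i \quad \text{at } z = 1 + i
\end{aligned}
$$

So 9.0 + 2.0im is the correct answer in all three cases. The wrong answer 16.0 + 2.0im differs from it by exactly 7 in the real part, which is the size of the 7 * real(...) term's contribution.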

I presume that #1102 and/or #1103 caused this regression. There were, and are, astonishingly few tests of mutable structs. If anyone has other use cases that ought to work and can be summarised compactly enough to add as tests, that would be extremely helpful.

There are other mysterious bugs that seem to be related to the handling of mutable structs, for which I don't have a compact example. It's likely that https://github.com/FluxML/Zygote.jl/issues/1109 is related.

Bugs involving only immutable structs are not in scope here. There are some remaining leaks of ChainRules types (every time you fix one, another appears); many are fixed by https://github.com/FluxML/Zygote.jl/pull/1104

lazarusA commented 1 year ago

Testing this now with Zygote v0.6.55 gives:

gradient(x -> abs2(x.x) + 7 * real(x.x), Ref(1+im))
((x = 9.0 + 2.0im,),)

gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im)))
((x = (x = 9.0 + 2.0im,),),)

It looks OK, apart from the obvious extra nesting.

I'm about to go down this road, using mutable structs. The docs say to avoid them if possible, but life will be much easier if they work as expected 😄.