FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Default presentation of gradients confusing? #1519

Closed: NAThompson closed this issue 2 months ago

NAThompson commented 2 months ago

When using Zygote with Unitful, the presentation of the gradient is a bit confusing:

using Unitful
using Zygote

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2u"m", 3u"s", 3e8u"m/s")
g = Zygote.gradient(f -> bar(f), foo)

println(g)

Output:

((x = 1.0, t = -3.0e8 m s⁻¹, c = -3.0 s),)

The numbers are correct, but shouldn't the fields be printed as (say) ẋ (typed as x\dot[TAB]) rather than x?

Or perhaps:

((∂f/∂x = 1.0, ∂f/∂t = -3.0e8 m s⁻¹, ∂f/∂c = -3.0 s),)
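For a flat gradient like the one above, the proposed labelling can already be produced on the user side, without any change to Zygote. A minimal sketch, assuming the gradient is a flat NamedTuple; `labelled` is a hypothetical helper (not a Zygote function), and the units are stripped for brevity:

```julia
# Hypothetical helper, not part of Zygote: label each top-level field of a
# gradient NamedTuple as ∂f/∂<field>. Only meaningful for flat gradients.
function labelled(∂::NamedTuple)
    join(("∂f/∂$(k) = $(v)" for (k, v) in pairs(∂)), ", ")
end

# The gradient from the report, with units stripped.
g = (x = 1.0, t = -3.0e8, c = -3.0)
println(labelled(g))  # ∂f/∂x = 1.0, ∂f/∂t = -3.0e8, ∂f/∂c = -3.0
```

If a field of the gradient were itself a NamedTuple, the value printed next to its label would be a whole nested structure, so the labels would no longer identify individual numbers.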

ToucheSir commented 2 months ago

There are many reasons to avoid being too clever when printing gradients (Zygote doesn't control how the types it returns are shown, indexed types like arrays would make the labels ambiguous, the field names of custom structs can't be changed in gradients, etc.), but the easiest way to demonstrate this is with a counterexample. What should be printed in this scenario?

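using Zygote

foo = (;
  x = (;
    y = 1,
    x = 2,
    c = 3,
  ),
  c = (;
    x = 2,
    z = ((; x = 2), (; x = 2))
  ),
  z = [(x = i, c = (; x = i + 1, c = i + 2)) for i in 1:100]  # 100 entries
)

# sums every number in the nested structure
f(foo) = foo.x.y + foo.x.x + foo.x.c + foo.c.x + foo.c.z[1].x + foo.c.z[2].x +
         sum(map(e -> e.x + e.c.x + e.c.c, foo.z))
∂f∂foo = gradient(f, foo)[1]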

It's clear from this that the proposed way of printing breaks down when there is any sort of nested structure in the gradients. That's not to say it lacks pedagogical value, but the marginal benefit over

∂f∂foo = gradient(f, foo)[1]
∂f∂x, ∂f∂t, ∂f∂c = ∂f∂foo

or

∂f∂foo = gradient(f, foo)[1]
∂f∂foo.x, ∂f∂foo.t, ∂f∂foo.c # more honest about the actual structure, and mirrors the input argument

is low and not worth the myriad problems and limitations it would bring.
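Both alternatives can be checked on the example from the original report. A minimal sketch, assuming Zygote is installed and replacing the Unitful quantities with plain Float64 values so that only Zygote is needed:

```julia
using Zygote

# Same shape as Foo in the report, but with Float64 fields instead of
# Unitful quantities (an assumption to keep the sketch dependency-light).
struct Foo
    x::Float64
    t::Float64
    c::Float64
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2.0, 3.0, 3e8)

# gradient returns a tuple with one entry per argument; for a struct argument
# the entry is a NamedTuple whose field names mirror the struct's fields.
∂f∂foo = gradient(bar, foo)[1]
println(∂f∂foo)  # (x = 1.0, t = -3.0e8, c = -3.0)

# Positional destructuring: a NamedTuple iterates over its values in field order.
∂f∂x, ∂f∂t, ∂f∂c = ∂f∂foo

# Field access keeps the input's field names, so the structure stays explicit.
println(∂f∂foo.t == ∂f∂t)  # true
```

Both forms give the same numbers; field access additionally mirrors the structure of the input argument, and it keeps working when the gradient is nested.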