Closed NAThompson closed 2 months ago
There are many reasons to avoid being too clever when printing out gradients (Zygote doesn't control the showing of types it returns, there are ambiguities with indexed types like arrays, field names of custom structs can't be changed in gradients, etc.), but the easiest way to demonstrate the problem is with a counterexample. What should be printed in this scenario?
foo = (;
    x = (;
        y = 1,
        x = 2,
        c = 3,
    ),
    c = (;
        x = 2,
        z = ((; x = 2), (; x = 2))
    ),
    z = [(x = 1, c = (; x = 2, c = 3)), (x = 2, c = (; x = 3, c = 4)), ... x100]
)
f(foo) = # sums every number in the nested structure
∂f∂foo = gradient(f, foo)[1]
It's clear from this that the proposed way of printing breaks down when there is any sort of nested structure in the gradients. That's not to say it lacks pedagogical value, but the marginal benefit over
∂f∂foo = gradient(f, foo)[1]
∂f∂x, ∂f∂t, ∂f∂c = ∂f∂foo
or
∂f∂foo = gradient(f, foo)[1]
∂f∂foo.x, ∂f∂foo.t, ∂f∂foo.c # more honest about the actual structure and mirrors input argument
is low and not worth the myriad problems and limitations it would bring.
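To make the counterexample concrete, here is a minimal sketch of an `f` that sums every number in the nested structure. The helper name `sum_nested` is hypothetical (it is not part of Zygote), and the `... x100` array is shortened to two elements so the snippet runs on its own in plain Julia.

```julia
# Hypothetical helper: recursively sum every number found in nested
# NamedTuples, Tuples, and arrays. Not part of Zygote.
sum_nested(x::Number) = x
sum_nested(x::NamedTuple) = sum_nested(values(x))   # drop the field names
sum_nested(x::Union{Tuple, AbstractArray}) = sum(sum_nested, x; init = 0)

# Same shape as the counterexample, with the array cut down to two entries.
foo = (;
    x = (; y = 1, x = 2, c = 3),
    c = (; x = 2, z = ((; x = 2), (; x = 2))),
    z = [(x = 1, c = (; x = 2, c = 3)), (x = 2, c = (; x = 3, c = 4))],
)

f(foo) = sum_nested(foo)

# With Zygote loaded, this `f` is what would be passed to
# `gradient(f, foo)[1]`; that call is left out so the sketch has no dependencies.
println(f(foo))
```

Even in this cut-down form the field name `x` appears at three different nesting levels, so any printing scheme keyed on field names alone would have to say which `x` it means.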
When using Zygote with Unitful, the presentation of the gradient is a bit confusing:
Output:
The numbers are correct, but shouldn't the fields be printed with (say) x\dot[TAB] rather than x? Or perhaps: