oxinabox opened this issue 3 years ago.
This feels to me like one of the peculiarities of using a `Symmetric` matrix to represent the (co)tangent of a `Symmetric` matrix. It's also where the semantics of `to_vec` and what we do with AD disagree, and it is one of the things we need to fix more generally. I believe it relates to https://github.com/JuliaDiff/FiniteDifferences.jl/issues/132.
If we pretend for a minute that the `Symmetric` matrix is a `Tangent` with a `parent` field, it would be correct. `sum` does indeed access the stored off-diagonal element of the parent matrix twice on the forwards pass, so a correct gradient for this operation would be `Tangent{Symmetric}(parent=[1 0; 2 1])` (or something like that).
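The `[1 0; 2 1]` above can be checked by differentiating `sum(Symmetric(P, :L))` with respect to each entry of the parent `P`. Below is a minimal sketch using only the `LinearAlgebra` stdlib; `parent_gradient` is a hypothetical helper, and storing the lower triangle (`:L`) is an assumption chosen to match that matrix (with `Symmetric`'s default `:U` you would get its transpose). Since the function is linear in `P`, a unit perturbation gives each partial derivative exactly.

```julia
using LinearAlgebra

# Hypothetical helper: gradient of a function that is linear in P, taken
# entry by entry with unit perturbations (exact for linear functions).
function parent_gradient(f, P::AbstractMatrix{<:Real})
    G = zeros(size(P))
    base = f(P)
    for i in eachindex(P)
        E = zeros(size(P))
        E[i] = 1          # unit perturbation of the i-th entry of the parent
        G[i] = f(P .+ E) - base
    end
    return G
end

# The 5.0 sits in the upper triangle, which Symmetric(_, :L) never reads.
P = [1.0 5.0; 3.0 4.0]
G = parent_gradient(X -> sum(Symmetric(X, :L)), P)   # G == [1 0; 2 1]
```

Only the stored off-diagonal entry `P[2, 1]` has derivative 2; `P[1, 2]` is never read, so its derivative is 0.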
I think the problem comes in when we convert between representations. In particular, we don't use a `Tangent`; we use another `Symmetric`, which interprets the `parent` field as a matrix and therefore gives us the symmetric matrix you're seeing. This is the essence of what `to_vec` is doing here, since `to_vec` doesn't know anything about the distinction between primals and tangents: it was devised before we really knew what we were doing there.
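Here is that reinterpretation step on its own, using only `LinearAlgebra`: take the structural cotangent `[1 0; 2 1]` (assuming the lower triangle is the stored one) and wrap it in another `Symmetric`. This is an illustration of the idea, not the actual `to_vec` code.

```julia
using LinearAlgebra

# Structural cotangent w.r.t. the parent field (lower triangle assumed stored).
Pbar = [1 0; 2 1]

# Reinterpreting it as a Symmetric matrix reads the stored lower triangle
# and mirrors it into the upper one, so the 2 now appears twice.
S = Symmetric(Pbar, :L)
M = Matrix(S)   # [1 2; 2 1]
```

Compare this with the all-ones matrix one would expect as the gradient of `sum`.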
To convert between the structural and natural cotangent for `Symmetric`, I think one needs to do some work: specifically, halve the strict lower triangle of the `parent` field and use it to construct a `Symmetric`. Similarly, to convert back you would need to double the strict lower triangle.
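A minimal sketch of those two conversions, assuming the primal is `Symmetric(P, :L)` so that only the lower triangle of `P` is read. `structural_to_natural` and `natural_to_structural` are hypothetical names, not functions from FiniteDifferences.jl or ChainRulesCore.jl, and `Hermitian` and the `:U` case are ignored.

```julia
using LinearAlgebra

# Hypothetical helpers; convention assumed: primal is Symmetric(P, :L).

# Structural cotangent (w.r.t. the parent field) -> natural cotangent
# (a Symmetric matrix): halve the strict lower triangle.
function structural_to_natural(Pbar::AbstractMatrix)
    N = float.(tril(Pbar))           # drop the never-read upper triangle
    for j in axes(N, 2), i in axes(N, 1)
        i > j && (N[i, j] /= 2)      # each stored off-diagonal entry appears twice
    end
    return Symmetric(N, :L)
end

# Natural cotangent -> structural cotangent: double the strict lower triangle.
function natural_to_structural(S::Symmetric)
    Pbar = tril(Matrix(S))           # lower triangle including the diagonal
    for j in axes(Pbar, 2), i in axes(Pbar, 1)
        i > j && (Pbar[i, j] *= 2)
    end
    return Pbar
end
```

Applied to the structural cotangent `[1 0; 2 1]` of `sum`, the first helper returns the all-ones matrix, which matches the natural gradient one would expect, and the second helper maps it back.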
Does that make sense as an explanation for what is going on (even if it doesn't provide an obvious solution)?
Edit: I suspect giving the `parent` field rather than collecting the matrix would resolve this particular problem, but other stuff from #132 then becomes relevant.
I wonder if it would just work if one uses the same approach as in https://github.com/JuliaDiff/FiniteDifferences.jl/pull/146. Similar to `Diagonal`, `Symmetric` is just a representation of a specific lower-dimensional manifold of the space of all matrices, so it seems a bit weird to go via `Matrix` (in both cases). Even if it would not fix the issue here, it would feel "more correct" to me.
Right, and that would be relatively simple: we would fix our triangular matrices to do that (right now they also go via `Matrix`) and then use the appropriate one of them to get the vector.
I agree with both of you, but I would direct you towards #132: we'll be changing semantics by doing this. I'm fine with making this change, but it is a pretty fundamental change, or at least it makes something that we already kind of do more explicit.
Although it is a fundamental change and probably breaks some things, I think one should change `to_vec` as suggested in https://github.com/JuliaDiff/FiniteDifferences.jl/issues/132.
I would expect `to_vec(x)` to return a canonical representation of the object `x` in R^n, and it seems reasonable to choose n as the intrinsic dimension of `x` (e.g. m in the case of an m × m `Diagonal` matrix). I also think it's reasonable that different objects, such as `(1, 3)` and `[1, 3]`, have the same canonical representation. The function `back` returned from `to_vec` contains the recipe for reconstructing the object, and hence I think equality-wise one should only require `x == y <=> |>(to_vec(x)...) == |>(to_vec(y)...)`.
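One way to read this proposal: the vector has exactly as many entries as `x` has degrees of freedom, and the closure rebuilds `x`. The sketch below uses a hypothetical `intrinsic_vec` (not the FiniteDifferences.jl API) for `Diagonal` (m numbers), `Symmetric` (n(n+1)/2 numbers, packed from the lower triangle, which is an arbitrary choice), and a tuple and a vector that share the same representation.

```julia
using LinearAlgebra

# Hypothetical `intrinsic_vec` (NOT FiniteDifferences.to_vec): returns a vector
# whose length is the intrinsic dimension of x, plus a closure rebuilding x.
function intrinsic_vec(D::Diagonal)
    v = collect(float.(D.diag))          # m numbers for an m×m Diagonal
    back(w) = Diagonal(w)
    return v, back
end

function intrinsic_vec(S::Symmetric)
    n = size(S, 1)
    # Only one triangle carries information; pack the lower one column by column.
    v = [float(S[i, j]) for j in 1:n for i in j:n]   # n(n+1)/2 numbers
    function back(w)
        L = zeros(eltype(w), n, n)
        k = 1
        for j in 1:n, i in j:n
            L[i, j] = w[k]
            k += 1
        end
        return Symmetric(L, :L)
    end
    return v, back
end

# Different objects with the same canonical representation.
intrinsic_vec(x::Vector{<:Real}) = (float.(x), w -> collect(w))
intrinsic_vec(t::Tuple{Vararg{Real}}) = (collect(float.(t)), w -> Tuple(w))
```

With this packing, perturbing one entry of the vector for a `Symmetric` input changes exactly one degree of freedom, rather than one of the two mirrored copies.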
This makes me wonder whether there is a fundamental difference between `ParameterHandling.flatten` and `to_vec`, or whether they could be unified in some way.
Glad we all agree. Doing this should also simplify a number of implementations, which is nice.
Consider: we know that the derivative of any `sum` is a collection of ones similar to the input, because `sum(xs) = xs[1] + xs[2] + ...`. However, for `Symmetric` it is giving `2` for the anti-diagonal on a 2x2 matrix. Similar issues occur for `prod`, which came up in https://github.com/JuliaDiff/ChainRules.jl/pull/335/files. Something must be wrong with how we are defining `to_vec`: https://github.com/JuliaDiff/FiniteDifferences.jl/blob/266d6faa8039c382d3fbe80f8ef0b91f6a09726c/src/to_vec.jl#L90-L96
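To see how a "go via `Matrix`" flattening can produce that `2`, here is a self-contained imitation of the pattern. `naive_to_vec` flattens the dense matrix and rebuilds by wrapping the reshaped vector in `Symmetric`; it is a hypothetical stand-in written for this illustration, not a copy of the linked `to_vec` code. `fd_gradient` is a plain central-difference loop, not FiniteDifferences.jl's `grad`.

```julia
using LinearAlgebra

# Hypothetical stand-in for a "go via Matrix" to_vec on Symmetric matrices:
# flatten all n^2 entries, rebuild by wrapping the reshaped vector in Symmetric.
function naive_to_vec(S::Symmetric)
    n = size(S, 1)
    uplo = S.uplo == 'U' ? :U : :L
    back(v) = Symmetric(reshape(v, n, n), uplo)
    return vec(Matrix(S)), back
end

# Central differences on the flattened vector, mapped back through `back`.
function fd_gradient(f, S::Symmetric; h = 1e-6)
    v, back = naive_to_vec(S)
    g = similar(v)
    for i in eachindex(v)
        e = zeros(length(v))
        e[i] = h
        g[i] = (f(back(v + e)) - f(back(v - e))) / (2h)
    end
    return back(g)
end

S = Symmetric([1.0 2.0; 2.0 3.0])   # uplo = :U by default
G = fd_gradient(sum, S)
# In the flattened vector, the slot below the diagonal is never read
# (derivative 0) and the slot above it feeds both off-diagonal entries
# (derivative 2); rebuilding with Symmetric mirrors the 2, so
# G ≈ [1 2; 2 1] rather than ones(2, 2).
```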