Closed jlperla closed 2 years ago
Just a reminder: the zero-mean MvNormal is a special type that we can dispatch on as required, or error on if it isn't fulfilled. So we need to make sure we use it. That is, note
julia> MvNormal([1.0 0.1; 0.1 1.0]) |> typeof
ZeroMeanFullNormal{Tuple{Base.OneTo{Int64}}} (alias for MvNormal{Float64, PDMats.PDMat{Float64, Array{Float64, 2}}, FillArrays.Zeros{Float64, 1, Tuple{Base.OneTo{Int64}}}})
vs.
julia> MvNormal([0.0, 0.0], [1.0 0.1; 0.1 1.0]) |> typeof
FullNormal (alias for MvNormal{Float64, PDMats.PDMat{Float64, Array{Float64, 2}}, Array{Float64, 1}})
The above Tangent{typeof(dist)}(; Σ = dΣ) type should take care of that, so we don't need to worry about actually specifying that type.
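As a minimal sketch of what dispatching on the zero-mean type could look like (the helper name require_zero_mean_cov is hypothetical and not part of any package), one method accepts the zero-mean specialization and a fallback errors for any other MvNormal:

```julia
using Distributions

# Hypothetical helper: only the zero-mean full-covariance type is accepted.
require_zero_mean_cov(d::ZeroMeanFullNormal) = Matrix(d.Σ)

# Fallback for any other MvNormal (e.g. FullNormal with an explicit mean vector).
function require_zero_mean_cov(d::MvNormal)
    throw(ArgumentError("expected a zero-mean MvNormal, got $(typeof(d))"))
end

Σ = [1.0 0.1; 0.1 1.0]
cov_ok = require_zero_mean_cov(MvNormal(Σ))    # dispatches to the zero-mean method
# require_zero_mean_cov(MvNormal([0.0, 0.0], Σ))  # would throw an ArgumentError
```

Note that an explicit mean vector of zeros still produces a FullNormal, so it lands on the fallback; only the single-argument constructor gives the zero-mean type.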
What I had might work, but it isn't that clear. The covariance matrix is always stored with its Cholesky decomposition. See the following for more on how to work with the tangent types for the Cholesky factor, which is the natural thing to work with in some cases:
using LinearAlgebra, Distributions, PDMats
using ChainRules: Tangent
# Could get a Cholesky directly,
# Σ_raw = [0.1 0.05; 0.05 0.1]
# C_fact = cholesky(Σ_raw) # if we are using choleskys for example
# Raw upper triangular
C = UpperTriangular([0.316228 0.158114; 0 0.273861])
dist = MvNormal(PDMat(Cholesky(C)))
dist.Σ.chol.U # If we ever need to access this in the cholesky value, this is the upper triangular part for example
# For the tangent type,
dChol = UpperTriangular([0.01 0.1; 0 0.4])
dΣ = Tangent{typeof(dist.Σ)}(; chol = dChol, mat = dChol' * dChol) # or something like that? Is the `mat` even necessary?
#Then the tangent type for the whole distribution is
dDist = Tangent{typeof(dist)}(; Σ = dΣ)
# To access it in the rules,
dDist.Σ.chol.U # etc.
@jlperla Tangent{typeof(dist)}(; Σ = dΣ) doesn't work because it needs a constructor for PDMat, and a type error would emerge.
If we have to construct directly from Cholesky...
Can you try to see if something like the following works?
using Distributions, ChainRules, LinearAlgebra, PDMats
using ChainRules: Tangent
# Whatever you have for the change coming upstream
#dist = MvNormal(PDMat(Cholesky(UpperTriangular([0.316228 0.158114; 0 0.273861])))) # This isn't strictly necessary, could just use covariance, but better to use chol if you have it.
dist = MvNormal([1.0 0.1; 0.1 1.0])
# OK, now inside of the custom rrule with the "dist" passed in as the argument
dΣ = [0.1 0.01; 0.01 0.1]
dΣ_mat = Tangent{typeof(dist.Σ)}(; mat = dΣ)
dDist = Tangent{typeof(dist)}(; Σ = dΣ_mat)
Note that what I did there was take the matrix change and create the PDMat Tangent type with it, but without the "chol". This means that you would need to access dDist.Σ and maybe dDist.Σ.mat for the change in the covariance matrix, and you cannot use the chol part of it. But I think you have the code using the matrices right now anyway, right?
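To make the access pattern concrete, here is a small sketch of reading the tangent back inside a rule. It assumes ChainRulesCore is available (ChainRules re-exports Tangent from it) and that, as in current ChainRulesCore versions as far as I can tell, a field that was not supplied reads back as a ZeroTangent:

```julia
using Distributions, LinearAlgebra, PDMats
using ChainRulesCore: Tangent, ZeroTangent

dist = MvNormal([1.0 0.1; 0.1 1.0])

# Build the structural tangent from a matrix perturbation only (no "chol").
dΣ = [0.1 0.01; 0.01 0.1]
dΣ_mat = Tangent{typeof(dist.Σ)}(; mat = dΣ)
dDist = Tangent{typeof(dist)}(; Σ = dΣ_mat)

# The change in the covariance matrix is available through the mat field.
dΣ_back = dDist.Σ.mat

# The chol field was never supplied, so it carries no tangent information.
chol_part = dDist.Σ.chol
```

So a pullback written against dDist.Σ.mat keeps working, while anything that reads dDist.Σ.chol would see a zero tangent rather than a factor.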
MvNormal instead of a Turing type in https://github.com/HighDimensionalEconLab/DifferentiableStateSpaceModels.jl/blob/main/src/types.jl#L377-L378 I think this should probably now be (or something like that, just to multiply the Cholesky through). We can come back another time to see whether the Cholesky is useful internally.
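One way to "multiply the Cholesky through" can be sketched as below. This is only an illustration under stated assumptions: the lower-triangular factor L is a hypothetical stand-in, not the actual field used in types.jl, and the Hermitian wrapper is there to remove floating-point asymmetry before the covariance is factorized again.

```julia
using Distributions, LinearAlgebra

# Hypothetical lower-triangular Cholesky factor of the covariance.
L = LowerTriangular([0.316228 0.0; 0.158114 0.273861])

# Multiply the factor through to recover a plain, exactly symmetric covariance matrix.
Σ = Matrix(Hermitian(L * L'))

# The single-argument constructor yields the zero-mean MvNormal type.
dist = MvNormal(Σ)
is_zero_mean = dist isa ZeroMeanFullNormal
```

This refactorizes Σ internally, which is redundant work if the factor is already at hand, but it keeps everything downstream working with plain matrices.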
Note that this is a specialized mean-zero type of the MvNormal.
probl.u0.Σ, I believe? A symmetric variance matrix. I.e., if you have the dΣ calculated, then it is easy enough to backprop. No need to do anything on the mean, since we construct with the zero-mean type and it should go back up the chain to the DifferentiableStateSpaceModels.
c.V as a matrix rather than Cholesky, then that is another easy solution which would have few effects. The only issue is that you would need to rethink the V_p stuff, because I think it was convenient as a Cholesky. But it doesn't matter... this is not a significant part of any calculation.
I should point out that the risk here is that there is some sort of weird "promotion" that happens with the ChainRules of an MvNormal to a TuringMvNormal behind the scenes... But I just don't see how that could happen in that call chain with a custom AD rule.