Closed by jlperla 2 years ago.
@wupeifan They told us the trick for CRTU (ChainRulesTestUtils): use the `rrule_f` argument.
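```julia
using Zygote, LinearAlgebra, ChainRulesTestUtils
N = 50
A = rand(N, N)
b = rand(N)
h(A, b) = sum(A \ b)
# rrule_f = rrule_via_ad makes CRTU test the rule Zygote derives, instead of a hand-written rrule
test_rrule(Zygote.ZygoteRuleConfig(), h, A, b; rrule_f = rrule_via_ad)
```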
It didn't work with a `Distribution` as an argument, but I think that was because of type-stability issues rather than a problem with CRTU itself.
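If we need to do things manually, we can always do

```julia
using Test, FiniteDiff
# finite_difference_gradient differentiates a one-argument function, so close over b
@test gradient(h, A, b)[1] ≈ FiniteDiff.finite_difference_gradient(X -> h(X, b), A) rtol=1e-5
```

or something like that.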
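For a dependency-free cross-check of the same `h(A, b) = sum(A \ b)`, the gradient can also be written out by hand: with x = A⁻¹b and λ = A⁻ᵀ1, we get ∂h/∂b = λ and ∂h/∂A = -λxᵀ. The sketch below uses only the standard library; `h_grad` and `fd_grad` are hypothetical helper names, not part of any package. It compares the closed form against central finite differences, fixes the seed, and adds `N * I` so `A` is well conditioned (an assumed way to keep the test from being flaky).

```julia
using LinearAlgebra, Random, Test

# Same function as above
h(A, b) = sum(A \ b)

# Hypothetical hand-derived gradient: with x = A \ b and λ = A' \ ones,
# ∂h/∂b = λ and ∂h/∂A = -λ * x'
function h_grad(A, b)
    x = A \ b
    λ = A' \ ones(length(b))
    return -λ * x', λ
end

# Hypothetical helper: central finite differences over every entry of an array
function fd_grad(f, X; ε = 1e-6)
    G = similar(X)
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += ε
        Xm = copy(X); Xm[i] -= ε
        G[i] = (f(Xp) - f(Xm)) / (2ε)
    end
    return G
end

Random.seed!(1234)          # fixed seed (the stream may still differ across Julia versions)
N = 5
A = rand(N, N) + N * I      # strictly diagonally dominant, hence well conditioned
b = rand(N)

dA, db = h_grad(A, b)
@test dA ≈ fd_grad(X -> h(X, b), A) rtol=1e-6
@test db ≈ fd_grad(y -> h(A, y), b) rtol=1e-6
```

The same closed form is the kind of thing a custom adjoint would encode, so it can double as a reference when comparing a hand-written rule against the generic Zygote one.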
The broad goals are: (1) remove the DSSM (DifferentiableStateSpaceModels) dependency from the tests after proving there are no weird interactions in the Zygote gradients; (2) ensure simple unit tests exist so that people tweaking performance won't accidentally break anything; and (3) prepare the exact setup we want to benchmark so that @jlperla can write a simple benchmark script.
At that point, this should be self-contained, so anyone can look at the performance of individual pieces of the code and run it on Julia 1.6 and 1.7.
The biggest priority here is to decouple the unit tests from DSSM, which has way too many upstream interactions and dependencies (especially things holding back Julia 1.7 testing). But I don't think we need to add new tests yet. We just don't want other people working on performance to accidentally break functionality.
So basically we need to get rid of https://github.com/SciML/DifferenceEquations.jl/blob/main/test/dssm.jl and convert it into `first_order.jl` or something along those lines. We can add 2nd order later, after we are done with first order. Part of doing first order is to see whether a custom adjoint for the linear simulations and joint likelihood is helpful relative to a generic one. If not, then 2nd order will follow relatively quickly and would be implemented in a very different way.

To get rid of DSSM, the things I see as likely important are:
- Put `A`, `B`, `C` and the `Diagonal(D)`, `observables`, and whatever else is necessary for the RBC example into literals in a unit test. I don't think we need to save them in files because they aren't that big. As for `u0` and `noise`, I don't think those are necessary to save either. But we don't want anything randomized in the unit tests.
- A general function to evaluate joint likelihoods. We can't pass in Distributions directly for unit tests, so the observation-noise variances `D` are passed as a plain vector:
```julia
using DifferenceEquations, Distributions, LinearAlgebra
using Test, Zygote, ChainRulesTestUtils

function joint_likelihood_1(A, B, C, u0, noise, observables, D)
    # Build the observation-noise distribution from the variance vector D inside the function
    problem = LinearStateSpaceProblem(A, B, C, u0, (0, length(noise)); noise, observables,
                                      obj_noise = MvNormal(zeros(length(D)), Diagonal(D)))
    return solve(problem, NoiseConditionalFilter(); save_everystep = false).loglikelihood
end

@testset "linear rbc joint likelihood" begin
    @test joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) ≈ WHATEVER IT IS
    @inferred joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) # would this catch inference problems in the solve?
    test_rrule(Zygote.ZygoteRuleConfig(), joint_likelihood_1, A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc,
               observables_rbc, D_rbc; rrule_f = rrule_via_ad)
end
```
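To get an independent value for the `WHATEVER IT IS` reference, or to sanity-check it, a plain loop over the state-space recursion may help. This is a hypothetical sketch, not the DifferenceEquations implementation: `joint_loglik_reference` and the `*_toy` literals are illustrative names, and several conventions are assumptions that would need to be checked against what `NoiseConditionalFilter` actually does. These include that observations at t = 1..T are compared with C uₜ (u₀ itself is not observed), that `noise` and `observables` are matrices with one column per period, and that `D` holds observation-noise variances.

```julia
using Test

# Hypothetical reference joint log-likelihood for a linear state-space model.
# Assumed convention: u_t = A u_{t-1} + B w_t for t = 1..T, starting from u0,
# and observables[:, t] = C u_t + v_t with v_t ~ N(0, Diagonal(D)).
function joint_loglik_reference(A, B, C, u0, noise, observables, D)
    T = size(noise, 2)
    size(observables, 2) == T || throw(DimensionMismatch("need one observation per noise column"))
    u = copy(u0)
    loglik = 0.0
    for t in 1:T
        u = A * u + B * noise[:, t]        # state transition driven by the given noise
        e = observables[:, t] - C * u      # observation error
        # log density of N(0, Diagonal(D)) evaluated at e
        loglik -= 0.5 * sum(log.(2π .* D) .+ e .^ 2 ./ D)
    end
    return loglik
end

# Deterministic toy literals (hypothetical), in the spirit of "nothing randomized"
A_toy = [0.9 0.0; 0.1 0.5]
B_toy = reshape([1.0, 0.0], 2, 1)
C_toy = [1.0 0.0]
u0_toy = zeros(2)
D_toy = [0.01]
noise_toy = [0.1 -0.2 0.05]
observables_toy = [0.1 -0.1 0.0]

# With zero noise and zero observables only the normalizing constant remains
@test joint_loglik_reference(A_toy, B_toy, C_toy, u0_toy, zeros(1, 3), zeros(1, 3), D_toy) ≈
      -1.5 * log(2π * 0.01)
joint_loglik_reference(A_toy, B_toy, C_toy, u0_toy, noise_toy, observables_toy, D_toy)
```

Since the loop only rebinds variables and never mutates arrays, Zygote should be able to differentiate it as well, which could make it a simple generic-AD baseline when evaluating a custom adjoint.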
- `joint_likelihood_1` for the RBC makes sense relative to our sampling speed.
- `rbc_kalman`, and then the `FVGQ joint 1` and `FVGQ kalman` tests, which load their matrices from files. The easiest path is to do `joinpath(pkgdir(DifferenceEquations), "test/data/FVGQ_A.csv")` or something along those lines.
- For the RBC literals, use `@show A_rbc;`, copy-paste the output, and then run the code formatter; these might be more compact than we would expect.

Otherwise, I think a few important cleanup things before we ask anyone else to look at this (because many might be on Julia 1.7) are:
- The `@requires` in there: see if we will implement custom adjoints, but that can wait.
- Remove unnecessary test dependencies (`DifferentiableStateSpaceModels` in particular, etc.) in https://github.com/SciML/DifferenceEquations.jl/blob/main/test/Project.toml. I don't see why we would need anything more than `Test`, `FiniteDiff`, and the AD packages after we move out Zygote properly, but I could be missing something.
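For the test-data setup discussed earlier (literals via `@show A_rbc;` for the small RBC matrices, CSV files for FVGQ), a quick round-trip check can confirm that nothing is lost along either route. This sketch uses only the `DelimitedFiles` standard library; `A_example` is a hypothetical stand-in for the real matrices, and a temporary directory replaces the `pkgdir`-based path so the snippet runs on its own.

```julia
using DelimitedFiles, Test

# Hypothetical stand-in for a model matrix such as A_rbc or the FVGQ A
A_example = [0.95 0.0 0.02; 0.1 0.9 0.0; 0.0 0.3 0.7]

# (1) Small matrices as literals: repr returns the text that `@show A` prints after "A = ",
# and Float64 values are printed with enough digits to round-trip exactly.
literal = repr(A_example)
println("A_rbc = ", literal)
@test eval(Meta.parse(literal)) == A_example

# (2) Larger matrices as CSV files. In the package the path would be built with
# joinpath(pkgdir(DifferenceEquations), ...); a temporary directory is used here instead.
path = joinpath(mktempdir(), "FVGQ_A.csv")
writedlm(path, A_example, ',')
A_loaded = readdlm(path, ',', Float64)
@test A_loaded == A_example
```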