SciML / DifferenceEquations.jl

Solving difference equations with DifferenceEquations.jl and the SciML ecosystem.
MIT License

Clean up unit tests for benchmarkable code for 1st order #24

Closed jlperla closed 2 years ago

jlperla commented 2 years ago

The broad goals: (1) remove the DSSM dependency from the tests after proving there is no weird interaction in the Zygote gradients; (2) ensure simple unit tests exist so that if people tweak performance it won't accidentally break anything; and (3) prepare the exact setup we want to benchmark so that @jlperla can write a simple benchmark script.

At that point, this should be self-contained for anyone to look at the performance of individual pieces of the code and to run on Julia 1.6 and 1.7.

The biggest task here is to get the unit tests decoupled from DSSM, which has way too many upstream interactions and dependencies (especially things holding back Julia 1.7 testing). But I don't think we need to add new tests yet. We just don't want other people working on performance to accidentally break functionality.

So basically we need to get rid of https://github.com/SciML/DifferenceEquations.jl/blob/main/test/dssm.jl and convert it into first_order.jl or something along those lines. We can add 2nd order later, after we are done with first order. Part of doing first order is to see whether a custom adjoint for the linear simulations and joint likelihood is helpful relative to a generic one. If not, then 2nd order will follow relatively quickly and would be implemented in a very different way.
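To make the custom-versus-generic adjoint question concrete, here is a minimal dependency-free sketch (all names are illustrative, not DifferenceEquations.jl API). For the forward recursion `x_{t+1} = A * x_t`, the hand-written adjoint recursion `λ_t = A' * λ_{t+1}` delivers the gradient of `sum(x_T)` with respect to `u0`, which we can verify against central finite differences:

```julia
using LinearAlgebra

# Forward pass: iterate x_{t+1} = A * x_t for T steps, return a scalar "loss".
function simulate_sum(A, u0, T)
    x = copy(u0)
    for _ in 1:T
        x = A * x
    end
    return sum(x)
end

# Hand-written adjoint: the reverse recursion λ_t = A' * λ_{t+1},
# seeded with λ_T = ones (the gradient of sum), yields d(loss)/d(u0).
function grad_u0(A, u0, T)
    λ = ones(length(u0))
    for _ in 1:T
        λ = A' * λ
    end
    return λ
end

N, T = 4, 5
A = 0.5 .* rand(N, N)
u0 = rand(N)
g = grad_u0(A, u0, T)

# Central differences; the model is linear in u0, so agreement is near-exact.
ϵ = 1e-6
g_fd = map(1:N) do i
    e = zeros(N); e[i] = ϵ
    (simulate_sum(A, u0 + e, T) - simulate_sum(A, u0 - e, T)) / (2ϵ)
end
@assert g ≈ g_fd
```

The hand-written reverse pass here needs only `A'` and one running vector, while a generic AD adjoint through `solve` also has to track the recorded solution; whether that saving matters in practice is exactly what the benchmark should reveal.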

To get rid of the DSSM, the things I see as likely important are:

A general function to evaluate joint likelihoods. We can't pass Distributions in directly for the unit tests:

```julia
function joint_likelihood_1(A, B, C, u0, noise, observables, D)
    problem = LinearStateSpaceProblem(A, B, C, u0, (0, length(noise)); noise, observables,
                                      obj_noise = MvNormal(zeros(length(D)), Diagonal(D)))
    return solve(problem, NoiseConditionalFilter(); save_everystep = false).loglikelihood
end
```

```julia
@testset "linear rbc joint likelihood" begin
    @test joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) ≈ WHATEVER IT IS
    @inferred joint_likelihood_1(A_rbc, B_rbc, C_rbc, u0_rbc, noise_rbc, observables_rbc, D_rbc) # would this catch inference problems in the solve?
    test_rrule(Zygote.ZygoteRuleConfig(), joint_likelihood_1, A_rbc, B_rbc, C_rbc, u0_rbc,
               noise_rbc, observables_rbc, D_rbc; rrule_f = rrule_via_ad)
end
```
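For intuition about what the joint likelihood is accumulating, here is a dependency-free mock of a joint log-likelihood for a linear Gaussian state-space model. The state-space form, the timing convention, and the reading of `D` as a vector of observation-noise variances are assumptions for illustration only, not the package's internals:

```julia
using LinearAlgebra

# Mock joint log-likelihood for an assumed linear state-space model
#   u_{t+1} = A u_t + B w_{t+1},   z_t = C u_t + v_t,   v_t ~ N(0, Diagonal(D)),
# given a fixed noise path `noise` (the w_t's) and `observables` (the z_t's).
function mock_joint_likelihood(A, B, C, u0, noise, observables, D)
    u = copy(u0)
    ll = 0.0
    for t in eachindex(noise)
        u = A * u + B * noise[t]
        resid = observables[t] - C * u
        # Log-density of N(0, Diagonal(D)) evaluated at resid
        ll += -0.5 * (length(resid) * log(2π) + sum(log, D) +
                      sum(abs2, resid ./ sqrt.(D)))
    end
    return ll
end
```

Because every operation is a plain matrix-vector product or a scalar reduction, this is the shape of function that both `@inferred` and `test_rrule` should be able to handle without a `Distribution` in the argument list.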

- [ ] Next we would want to make this work to find the gradient for a mock version of our likelihood. See the above code as well.
- [x] Later we can be more careful about the inference issues of the solve itself and the constructor. Maybe code like the following added to the above, but it also might not be necessary.
```julia
    problem = LinearStateSpaceProblem(x.A, x.B, x.C, x.u0, (0, length(x.noise)); noise = x.noise, obj_noise = MvNormal(SOMETHING WITH x.D), observables = x.observables)
    @inferred LinearStateSpaceProblem(x.A, x.B, x.C, x.u0, (0, length(x.noise)); noise = x.noise, obj_noise = MvNormal(SOMETHING WITH x.D), observables = x.observables)
    @inferred solve(problem, NoiseConditionalFilter(); save_everystep = false)
```

Otherwise, I think there are a few important cleanup items before we ask anyone else to look at this (because many might be on Julia 1.7):

jlperla commented 2 years ago

@wupeifan They told us the trick for ChainRulesTestUtils (CRTU): use the `rrule_f` argument.

```julia
using Zygote, LinearAlgebra, ChainRulesTestUtils
N = 50
A = rand(N, N)
b = rand(N)
h(A, b) = sum(A \ b)
test_rrule(Zygote.ZygoteRuleConfig(), h, A, b; rrule_f = rrule_via_ad)
```

It didn't work with a Distribution as an argument, but I think that was because of type-stability issues rather than CRTU itself having problems.
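As a reminder of what `@inferred` buys us, here is a tiny self-contained example (nothing here is package code): it throws when the actual return type is narrower than what inference can prove, which is exactly the kind of instability that passing a `Distribution` through can introduce.

```julia
using Test

# Type-stable: the return type is Float64 for any input value.
stable(flag) = flag ? 1.0 : 2.0

# Type-unstable: the return type depends on the *value* of `flag`,
# so inference can only prove Union{Float64, Int64}.
unstable(flag) = flag ? 1 : 1.0

@inferred stable(true)                                 # passes silently
@test_throws ErrorException @inferred(unstable(true))  # the instability is caught
```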

If we need to do things manually, we can always do

```julia
using FiniteDiff
@test gradient(h, A, b)[1] ≈ FiniteDiff.finite_difference_gradient(A -> h(A, b), A)
```

or something like that.
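For `h(A, b) = sum(A \ b)` specifically, the analytic gradients are small enough to check by hand without any AD or finite-difference package (all names below are illustrative): with `y = A \ b` and `g = A' \ ones(N)`, we have `∂h/∂b = g` and `∂h/∂A = -g * y'`, which a central-difference check confirms.

```julia
using LinearAlgebra

h(A, b) = sum(A \ b)

N = 4
A = rand(N, N) + N * I   # shift the spectrum so A is well-conditioned
b = rand(N)

# Analytic gradients: h = ones' * inv(A) * b, so with y = A \ b and
# g = A' \ ones(N) we get ∂h/∂b = g and ∂h/∂A = -g * y'.
y = A \ b
g_b = A' \ ones(N)
g_A = -g_b * y'

# Central-difference check of ∂h/∂b
ϵ = 1e-6
fd_b = map(1:N) do i
    e = zeros(N); e[i] = ϵ
    (h(A, b + e) - h(A, b - e)) / (2ϵ)
end
@assert isapprox(g_b, fd_b; rtol = 1e-6)

# Spot-check one entry of ∂h/∂A
E = zeros(N, N); E[2, 3] = ϵ
@assert isapprox(g_A[2, 3], (h(A + E, b) - h(A - E, b)) / (2ϵ); rtol = 1e-5)
```

Having the closed-form gradient on hand is also useful for debugging: if `test_rrule` fails, a check like this distinguishes a wrong adjoint from a tolerance problem in the finite differencing.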