CDCgov / Rt-without-renewal

https://cdcgov.github.io/Rt-without-renewal/
Apache License 2.0

Resolve gradient mismatches in benchmarks #340

Closed · seabbs closed this 1 month ago

seabbs commented 3 months ago

For several utilities, benchmarking suggests that different backends give different gradients. This should be investigated as it may indicate performance issues.
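
For reference, a minimal sketch of the kind of cross-backend comparison involved (the target function here is a stand-in, not an actual EpiAware log-density):

```julia
using ForwardDiff, ReverseDiff, Zygote

# Stand-in log-density; in practice this would be a model's log-density.
target(θ) = sum(abs2, θ) - log1p(θ[1]^2)

θ = randn(3)
g_fd = ForwardDiff.gradient(target, θ)
g_rd = ReverseDiff.gradient(target, θ)
g_zy = first(Zygote.gradient(target, θ))

# All backends should agree to floating-point tolerance;
# a mismatch is what the benchmarks are flagging.
isapprox(g_fd, g_rd; rtol = 1e-8) && isapprox(g_fd, g_zy; rtol = 1e-8)
```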

SamuelBrand1 commented 3 months ago

Is the difference between all options, or is Zygote the outlier? ReverseDiff with a compiled tape is known to have these problems with logical branches in the code.
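
For illustration, a toy reproduction of that failure mode (not project code): the compiled tape freezes whichever branch was taken while recording, so inputs that should take the other branch get silently wrong gradients.

```julia
using ReverseDiff, ForwardDiff

f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

# Recording at x[1] > 0 bakes that branch into the compiled tape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

x = [-1.0, 2.0]                                  # should take the `sum(x)` branch
ReverseDiff.gradient!(similar(x), tape, x)       # [-2.0, 4.0] — replays the recorded branch
ForwardDiff.gradient(f, x)                       # [1.0, 1.0]  — correct
```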

seabbs commented 3 months ago

It's the compiled tape that has issues.

SamuelBrand1 commented 3 months ago

That makes sense... The improvement path here is to track down the "wrong" branch(es) that are getting compiled.
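
One way to track them down (a sketch, not a polished tool): uncompiled ReverseDiff re-records the tape, and hence the control flow, on every call, so diffing it against the compiled tape at random inputs flags the inputs that exercise a frozen branch.

```julia
using ReverseDiff

f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)          # toy function from above

tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

for _ in 1:100
    x = randn(2)
    g_compiled = ReverseDiff.gradient!(similar(x), tape, x)
    g_fresh = ReverseDiff.gradient(f, x)          # fresh tape, branch re-evaluated
    isapprox(g_compiled, g_fresh) || @warn "frozen-branch mismatch" x
end
```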

seabbs commented 2 months ago
[Screenshot: benchmark output, 2024-07-18]

This is what I see in our current benchmarks. Could really do with an improved warning message here to help localise the problem.
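
Something like the following could help (a hypothetical helper, not what we have now): capture warnings per benchmark and report them against the model's label so the offending model is obvious.

```julia
using Logging

# Run `f`, collecting any warnings it emits, and tag them with `label`
# (e.g. the benchmarked model's name) so mismatches are easy to localise.
function run_with_warnings(f, label)
    io = IOBuffer()
    with_logger(SimpleLogger(io, Logging.Warn)) do
        f()
    end
    msgs = String(take!(io))
    isempty(msgs) || @info "Warnings from $label" msgs
end
```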

seabbs commented 2 months ago
[Screenshot: benchmark output, 2024-07-22]

In the latest benchmarks in #392 I still see a single instance of this that needs resolving.

SamuelBrand1 commented 2 months ago

Where is this benchmark?

seabbs commented 2 months ago

I think we can also use instabilities in our benchmarks (e.g. https://github.com/CDCgov/Rt-without-renewal/pull/400#issuecomment-2250988929) to indicate where problems are. It would be useful to consider whether there is a more formal way of checking this that requires less tracking across PRs.

(i.e. here it suggests that observation error models are problematic; see the sketch below).
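
A sketch of what a more formal check could look like (assuming ForwardDiff as the reference backend): a property-style test that compares compiled-tape gradients against forward mode at many random points, so mismatches fail CI instead of needing to be spotted across PRs.

```julia
using Test, ForwardDiff, ReverseDiff

# `logdensity` and `dim` stand in for a model's log-density and its
# parameter dimension; both are placeholders, not EpiAware API.
function test_gradient_agreement(logdensity, dim; n = 50, rtol = 1e-6)
    tape = ReverseDiff.compile(ReverseDiff.GradientTape(logdensity, randn(dim)))
    @testset "compiled tape vs ForwardDiff" begin
        for _ in 1:n
            θ = randn(dim)
            g_ref = ForwardDiff.gradient(logdensity, θ)
            g_cmp = ReverseDiff.gradient!(similar(θ), tape, θ)
            @test isapprox(g_ref, g_cmp; rtol = rtol)
        end
    end
end
```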

seabbs commented 2 months ago

#414 localised some issues to the NegativeBinomialError model. It would be good to investigate this further to see if it is common across all error models or just the negative binomial.

See: https://github.com/CDCgov/Rt-without-renewal/pull/414#issuecomment-2263521390
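
As a starting point for that investigation, here is a hand-written negative binomial log-density (assuming a mean/overdispersion parameterisation; the exact internals of NegativeBinomialError may differ) that makes the whole computation visible to the AD backends:

```julia
using ForwardDiff, ReverseDiff
using SpecialFunctions: loggamma

# μ is the mean, α an overdispersion parameter; p = r / (r + μ) with r = 1/α
# gives a negative binomial with mean μ. This parameterisation is an
# assumption for the sketch, not necessarily what NegativeBinomialError uses.
function nb_loglik(θ, y)
    μ, α = θ
    r = 1 / α
    p = r / (r + μ)
    return sum(loggamma(yi + r) - loggamma(r) - loggamma(yi + 1) +
               r * log(p) + yi * log1p(-p) for yi in y)
end

y = fill(10, 10)                     # mirrors the y_t fixture in the warnings below
ℓ(θ) = nb_loglik(θ, y)

θ = [10.0, 0.05]
tape = ReverseDiff.compile(ReverseDiff.GradientTape(ℓ, θ))
ForwardDiff.gradient(ℓ, θ) ≈ ReverseDiff.gradient!(similar(θ), tape, θ)
```

If the plain density agrees across backends, the frozen branch is more likely in the surrounding wrappers (e.g. Ascertainment) than in the likelihood itself.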

seabbs commented 2 months ago

Running some more repetitions, I see the following throwing gradient issues (due to compiled ReverseDiff):

| Type | Count |
| --- | --- |
| `Ascertainment{NegativeBinomialError}` | 3 |
| `NegativeBinomialError{HalfNormal{Float64}}` | 2 |
  1. Warnings from Model{typeof(generate_observations), (:obs_model, :y_t, :Y_t), (), (), Tuple{Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#88#89", String}, Vector{Int64}, Vector{Int64}}, Tuple{}, DefaultContext}(EpiAware.EpiAwareBase.generate_observations, (obs_model = Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#88#89", String}(NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.01)), PrefixLatentModel{FixedIntercept{Float64}, String}(FixedIntercept{Float64}(0.1), "Ascertainment"), var"#88#89"(), "Ascertainment"), y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100], Y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]), NamedTuple(), DefaultContext()):
    ┌ Warning: `ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`
  2. Warnings from Model{typeof(generate_observations), (:obs_model, :y_t, :Y_t), (), (), Tuple{Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#82#83", String}, Vector{Int64}, Vector{Int64}}, Tuple{}, DefaultContext}(EpiAware.EpiAwareBase.generate_observations, (obs_model = Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#82#83", String}(NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.01)), PrefixLatentModel{FixedIntercept{Float64}, String}(FixedIntercept{Float64}(0.1), "Ascertainment"), var"#82#83"(), "Ascertainment"), y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100], Y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]), NamedTuple(), DefaultContext()):
  3. Warnings from Model{typeof(generate_observations), (:obs_model, :y_t, :Y_t), (), (), Tuple{Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#64#65", String}, Vector{Int64}, Vector{Int64}}, Tuple{}, DefaultContext}(EpiAware.EpiAwareBase.generate_observations, (obs_model = Ascertainment{NegativeBinomialError{HalfNormal{Float64}}, AbstractTuringLatentModel, var"#64#65", String}(NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.01)), PrefixLatentModel{FixedIntercept{Float64}, String}(FixedIntercept{Float64}(0.1), "Ascertainment"), var"#64#65"(), "Ascertainment"), y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100], Y_t = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100]), NamedTuple(), DefaultContext())
  4. Warnings from Model{typeof(generate_observations), (:obs_model, :y_t, :Y_t), (), (), Tuple{NegativeBinomialError{HalfNormal{Float64}}, Vector{Float64}, Vector{Float64}}, Tuple{}, DefaultContext}(EpiAware.EpiAwareBase.generate_observations, (obs_model = NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.01)), y_t = [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], Y_t = [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]), NamedTuple(), DefaultContext())
  5. Warnings from Model{typeof(generate_observations), (:obs_model, :y_t, :Y_t), (), (), Tuple{NegativeBinomialError{HalfNormal{Float64}}, Vector{Float64}, Vector{Float64}}, Tuple{}, DefaultContext}(EpiAware.EpiAwareBase.generate_observations, (obs_model = NegativeBinomialError{HalfNormal{Float64}}(HalfNormal{Float64}(μ=0.01)), y_t = [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0], Y_t = [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]), NamedTuple(), DefaultContext())

Something of a pattern I think!

seabbs commented 1 month ago

I think with #442 and #415 these are now all handled, so closing.