CDCgov / PyRenew

Python package for multi-signal Bayesian renewal modeling with JAX and NumPyro.
https://cdcgov.github.io/PyRenew/
Apache License 2.0

Incorrect convolve mode in `hospitaladmissions.py` #385

Closed. damonbayer closed this issue 2 months ago

damonbayer commented 2 months ago

We should use mode="valid" to ensure the output reflects the entire infection_to_admission_interval being used.

https://github.com/CDCgov/multisignal-epi-inference/blob/main/model/src/pyrenew/latent/hospitaladmissions.py#L189-L193

I believe these lines should be replaced with

        latent_hospital_admissions = jnp.convolve(
            latent_hospital_admissions_raw,
            infection_to_admission_interval.value,
            mode="valid",
        )

dylanhmorris commented 2 months ago

I think the previous version was to allow the kind of behavior we've now decided to forbid. But I agree with forbidding it, at least by default.

damonbayer commented 2 months ago

The issue is actually more with the post-convolve indexing, but this change makes it easier to avoid that error.

sbidari commented 2 months ago

> The issue is actually more with the post-convolve indexing, but this change makes it easier to avoid that error.

Can you point me to the issue/error that is being referenced here for my knowledge? @damonbayer @dylanhmorris

damonbayer commented 2 months ago

With mode="full" you will get results that implicitly pad the arrays with 0s. When using mode="valid", this does not happen.

With mode="full", you then have to slice the head and the tail of the resulting array to get the desired output. With mode="valid" you only need the tail.
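
A minimal sketch of the difference, using NumPy's `np.convolve` on made-up toy arrays (`jnp.convolve` accepts the same `mode` values). The names `signal` and `kernel` are illustrative and do not come from `hospitaladmissions.py`.

```python
import numpy as np

# Toy inputs (hypothetical values, not taken from the model).
signal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])  # length N = 5
kernel = np.array([0.5, 0.3, 0.2])            # length M = 3
n, m = signal.size, kernel.size

full = np.convolve(signal, kernel, mode="full")    # length N + M - 1
valid = np.convolve(signal, kernel, mode="valid")  # length N - M + 1

print(full.shape)   # (7,)
print(valid.shape)  # (3,)

# The first "full" entry only sees signal[0] * kernel[0]; the rest of the
# kernel overlaps implicit zeros.
print(full[0])  # 0.5

# "valid" is "full" with the M - 1 zero-padded entries dropped at each end.
assert np.allclose(valid, full[m - 1 : n])
```

In this toy case the first and last M - 1 entries of the "full" output are computed against implicit zeros, which is why both ends have to be sliced off to recover what "valid" returns directly.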

sbidari commented 2 months ago

> With mode="full", you then have to slice the head and the tail of the resulting array to get the desired output. With mode="valid" you only need the tail.

The current implementation only includes slicing one end of the array resulting from convolution.

Using mode="valid" results in a latent_hospital_admissions array smaller than the desired shape, i.e., latent_hospital_admissions_raw.shape[0]. Did you mean mode="same"?

damonbayer commented 2 months ago

> Using mode="valid" results in a latent_hospital_admissions array smaller than the desired shape, i.e., latent_hospital_admissions_raw.shape[0]. Did you mean mode="same"?

No, the "valid" mode is correct. The desired shape in the end should be the length of the observed admissions. We can do that here or later on in the "model" code.
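
For comparison, a small sketch of what mode="same" returns, again using NumPy's `np.convolve` on made-up values rather than the model's arrays. It suggests one possible reason "same" is not a drop-in alternative: the output matches the input length, but it is centered, and its edge entries still depend on zero padding.

```python
import numpy as np

# Toy inputs (hypothetical values, not taken from the model).
signal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = np.array([1.0, 1.0, 1.0])
n = signal.size

full = np.convolve(signal, kernel, mode="full")    # [1, 3, 6, 9, 12, 9, 5]
same = np.convolve(signal, kernel, mode="same")    # [3, 6, 9, 12, 9]
valid = np.convolve(signal, kernel, mode="valid")  # [6, 9, 12]

# "same" is the centered window of "full", not the leading window full[:n].
assert np.array_equal(same, full[1:6])
assert not np.array_equal(same, full[:n])

# Only the interior of "same" is free of zero padding; it matches "valid".
assert np.array_equal(same[1:-1], valid)
```

With this kernel, the first and last entries of the "same" output (3 and 9) each sum only two real values plus an implicit zero.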

sbidari commented 2 months ago

> The desired shape in the end should be the length of the observed admissions. We can do that here or later on in the "model" code.

Right! We want the final latent_hospital_admissions to be the same length as observed_hosp_admissions. But convolve mode = "valid" yields latent_hospital_admissions array with length smaller than observed_hosp_admissions.

damonbayer commented 2 months ago

> But convolve mode = "valid" yields latent_hospital_admissions array with length smaller than observed_hosp_admissions.

What are the lengths of the relevant quantities?

sbidari commented 2 months ago

> > But convolve mode = "valid" yields latent_hospital_admissions array with length smaller than observed_hosp_admissions.
>
> What are the lengths of the relevant quantities?

    latent_hospital_admissions = jnp.convolve(
        latent_hospital_admissions_raw,
        infection_to_admission_interval.value,
        mode="full",
    )[:latent_hospital_admissions_raw.shape[0]]
    latent_hospital_admissions.shape
    # (105,)

    latent_hospital_admissions = jnp.convolve(
        latent_hospital_admissions_raw,
        infection_to_admission_interval.value,
        mode="valid",
    )
    latent_hospital_admissions.shape
    # (51,)

    observed_hosp_admissions.shape
    # (90,)
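
These shapes pin down the one length that is not printed. A rough sketch of the arithmetic, assuming infection_to_admission_interval is shorter than the raw series and that jnp.convolve follows the usual NumPy length rule for mode="valid". The helper `valid_length` and the inferred interval length are illustrative and are not stated anywhere in this thread.

```python
def valid_length(n_signal: int, n_kernel: int) -> int:
    """Output length of a 1-D convolution with mode="valid" (NumPy rule)."""
    return max(n_signal, n_kernel) - min(n_signal, n_kernel) + 1


n_raw = 105      # latent_hospital_admissions_raw.shape[0]
n_valid = 51     # shape reported for mode="valid"
n_observed = 90  # observed_hosp_admissions.shape[0]

# Inferred (not stated in the thread): length of the interval array,
# assuming it is shorter than the raw series.
n_interval = n_raw - n_valid + 1
assert valid_length(n_raw, n_interval) == n_valid

# Raw series length that would make the "valid" output cover every
# observed admission.
n_raw_needed = n_observed + n_interval - 1
assert valid_length(n_raw_needed, n_interval) == n_observed

print(n_interval, n_raw_needed)
```

Under those assumptions the interval has 55 entries, so mode="valid" would return 90 values only if the raw series had 144 entries rather than 105.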