tedwards2412 / ripple

Differentiable Gravitational Waveforms with JAX

Update typing to use jaxtyping #24

Open kazewong opened 4 months ago

kazewong commented 4 months ago

Currently, the type annotations do not use jaxtyping (https://docs.kidger.site/jaxtyping/), which lets annotations record the dtype and shape of each array and can check them at runtime.

For example, in the current typing scheme, a function could look something like

from jax import Array

def f(x: Array) -> Array:
    ...

With jaxtyping, the function should look like:

from jaxtyping import Array, Float

def f(x: Float[Array, " n_samples"]) -> Float[Array, " n_samples"]:
    ...

which also gives the user more information. I think this is better, since the shape of an array is usually very important.