patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Neural SDE example does not run #479

Closed: BenjaminJCox closed this issue 3 months ago

BenjaminJCox commented 3 months ago

Hello, and apologies in advance for the formatting of this issue.

I am attempting to run the Neural SDE example in Colab (after installing the requisite packages), and as written it does not run.

It first throws the error:

     21
     22     def __call__(self, t, y, args):
---> 23         return self.scale * self.mlp(jnp.concatenate([t[None], y]))
     24
     25

TypeError: 'float' object is not subscriptable

(Colab producing this error here: https://colab.research.google.com/gist/BenjaminJCox/c00701b3992054dc2e31ec0a86b32244/neural_sde.ipynb)
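The TypeError can be reproduced outside Diffrax with nothing but a scalar time value. The sketch below (the variable names are illustrative, not taken from the example) shows that `[None]` fails on a raw Python float but works on a 0-dimensional JAX array, and that converting with `jnp.asarray` first makes both cases behave identically:

```python
import jax.numpy as jnp

t_float = 0.5               # a raw Python float
t_array = jnp.asarray(0.5)  # a 0-dimensional JAX array

# Indexing a Python float with [None] is what raises the TypeError above.
try:
    t_float[None]
    float_indexing_failed = False
except TypeError:
    float_indexing_failed = True

# A 0-d JAX array supports [None], which adds a leading axis of length 1.
expanded = t_array[None]

# Converting with jnp.asarray first makes the float case work the same way.
expanded_from_float = jnp.asarray(t_float)[None]

print(float_indexing_failed)      # True
print(expanded.shape)             # (1,)
print(expanded_from_float.shape)  # (1,)
```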

After removing the indexing from the "t" terms referenced in the error, we receive the following error, which relates more directly to this library:

   1026         y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
   1027     ):
-> 1028         raise ValueError(
   1029             "terms must be a PyTree of AbstractTerms (such as ODETerm), with "
   1030             f"structure {solver.term_structure}"

ValueError: terms must be a PyTree of AbstractTerms (such as ODETerm), with structure <class 'diffrax._term.AbstractTerm'>

(Colab producing this error here: https://colab.research.google.com/gist/BenjaminJCox/cbb19388b655b58a9c99764754db8a1a/neural_sde.ipynb)

As far as I can tell, the error makes no sense: in both the SDE and the CDE, vf and cvf are an ODETerm and a ControlTerm respectively, with MultiTerm combining them. I suspect that something has changed in the API since the example was last updated.

I believe that having an up-to-date set of examples is very valuable, and having these examples to point my research group at would be very helpful in the future. Unfortunately I can't see why the error is occurring, but if you have any ideas I will gladly try them and let you know.

BenjaminJCox commented 3 months ago

Here is a similar example (from the introductory notes) that runs fine. The terms are constructed in a similar way: both use MultiTerm, ODETerm and ControlTerm, although of course the neural example is more complex.

notebook here: https://colab.research.google.com/gist/BenjaminJCox/1bdeafb761d4a851c152e933cf25c2f5/untitled3.ipynb

patrick-kidger commented 3 months ago

Thanks for the report! The example should be fixed in https://github.com/patrick-kidger/diffrax/pull/480. The docs will update shortly.

It just so happened that in this example we always called the vector field with an array. However, in general we call the vector field with an arraylike (e.g. a raw Python float), and in more recent releases of Diffrax we now actually do that.
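A vector field that tolerates any arraylike `t` can convert it before indexing. The following is a minimal sketch, not necessarily the exact change made in PR #480: the example's MLP is replaced by a hypothetical fixed linear map (`weight`), and `scale` and `state_dim` are likewise illustrative stand-ins. It checks that calling with a Python float and with a JAX array gives the same result:

```python
import jax.numpy as jnp

# Hypothetical stand-in for the example's MLP: a fixed linear map from
# (1 + state_dim) inputs to state_dim outputs.
state_dim = 3
weight = jnp.arange(state_dim * (state_dim + 1), dtype=jnp.float32).reshape(
    state_dim, state_dim + 1
) / 10.0
scale = 2.0


def vector_field(t, y, args):
    # jnp.asarray accepts a Python float or a JAX array, so [None]
    # always yields a shape-(1,) array that can be prepended to y.
    t_vec = jnp.asarray(t, dtype=y.dtype)[None]
    return scale * (weight @ jnp.concatenate([t_vec, y]))


y = jnp.ones(state_dim, dtype=jnp.float32)
out_from_float = vector_field(0.5, y, None)
out_from_array = vector_field(jnp.asarray(0.5), y, None)

print(out_from_float.shape)                                # (3,)
print(bool(jnp.allclose(out_from_float, out_from_array)))  # True
```

`jnp.atleast_1d(t)` would be an equally valid way to obtain a shape-(1,) array here; `jnp.asarray(...)[None]` is used only because it stays closest to the original `t[None]` expression.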

BenjaminJCox commented 3 months ago

Thank you for getting back so quickly, that seems to have fixed it!