poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

Random state within a Loss #215

Open cgarciae opened 2 years ago

cgarciae commented 2 years ago

Losses currently cannot use random state. While it's not that common, it is sometimes needed and should be supported. There are two ways of going about this:

  1. Let Losses optionally accept a next_key: KeySeq argument. No modifications to Loss are needed, but Model would then need to have a next_key: KeySeq field (not that bad).
  2. Let Loss inherit from treex.Treex, or even just treeo.Tree, so users can create their own random state. The downside is that losses become potentially stateful, and we would need to watch out for the implications when they are used inside a Metric, e.g. Losses and friends.
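
To make option 1 concrete, here is a rough sketch in plain JAX. It is not Elegy's or treex's actual code: the `KeySeq` class below is a hypothetical stand-in for a key sequence, and `call_loss`, `noisy_mse`, and `plain_mse` are illustrative names. The idea is that the caller (the Model, in the proposal) owns the key sequence and passes it only to losses whose signature asks for `next_key`.

```python
import inspect

import jax
import jax.numpy as jnp


class KeySeq:
    """Hypothetical stand-in for a key sequence: each call returns a fresh PRNG key."""

    def __init__(self, seed: int):
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Split so the stored key is never reused for sampling.
        self.key, subkey = jax.random.split(self.key)
        return subkey


def noisy_mse(target, preds, next_key):
    # A loss that needs randomness: perturb the targets before comparing.
    noise = 0.1 * jax.random.normal(next_key(), target.shape)
    return jnp.mean((preds - (target + noise)) ** 2)


def plain_mse(target, preds):
    # A loss that does not need randomness and keeps its usual signature.
    return jnp.mean((preds - target) ** 2)


def call_loss(loss_fn, next_key, **kwargs):
    # Inject next_key only when the loss declares it as a parameter.
    if "next_key" in inspect.signature(loss_fn).parameters:
        kwargs["next_key"] = next_key
    return loss_fn(**kwargs)


next_key = KeySeq(42)  # would live on the Model in option 1
target = jnp.ones((4,))
preds = jnp.zeros((4,))

loss_a = call_loss(noisy_mse, next_key, target=target, preds=preds)
loss_b = call_loss(plain_mse, next_key, target=target, preds=preds)
```

Existing losses keep working unchanged because the argument is opt-in. Note that this stand-in mutates Python state on each call, so under `jax.jit` the key would have to be threaded through explicitly; how that interacts with Model is part of what this issue needs to settle.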