cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Call Module.__call__ during init by default #59

Status: Open (opened by cgarciae 2 years ago)

cgarciae commented 2 years ago

Now that all basic layers use shape inference, it would be better for `init` to call the forward method `__call__` by default. This would avoid the problem of submodules not being initialized (or even created) after a call to `init` when the user forgets to pass the inputs.
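To illustrate the failure mode, here is a minimal pure-Python sketch. It is not treex's implementation. `Module`, `LazyLinear`, `MLP` and the `init(key, inputs=None)` signature are hypothetical stand-ins. Python lists and `random.Random` replace JAX arrays and PRNG keys so the sketch runs without dependencies. A shape-inferring layer only creates its kernel during its first forward pass, so an `init` that never runs `__call__` leaves those parameters missing.

```python
import random
from typing import List, Optional


class Module:
    """Hypothetical base class (NOT treex's real Module)."""

    def __init__(self) -> None:
        self.initialized = False
        self._rng: Optional[random.Random] = None

    def children(self) -> List["Module"]:
        # Direct submodules stored as attributes.
        return [v for v in vars(self).values() if isinstance(v, Module)]

    def _set_rng(self, rng: random.Random) -> None:
        # Share one RNG (stand-in for a JAX PRNG key) across the module tree.
        self._rng = rng
        for child in self.children():
            child._set_rng(rng)

    def init(self, key: int, inputs=None) -> "Module":
        self._set_rng(random.Random(key))
        if inputs is not None:
            # Run the forward pass once so shape-inferring layers can
            # create their parameters from the real input shapes.
            self(inputs)
        self.initialized = True
        return self


class LazyLinear(Module):
    """Hypothetical linear layer whose input size is inferred on first call."""

    def __init__(self, features_out: int) -> None:
        super().__init__()
        self.features_out = features_out
        self.kernel: Optional[List[List[float]]] = None  # created lazily

    def __call__(self, x: List[float]) -> List[float]:
        if self.kernel is None:
            features_in = len(x)  # shape inference from the actual input
            self.kernel = [
                [self._rng.gauss(0.0, 1.0) for _ in range(self.features_out)]
                for _ in range(features_in)
            ]
        return [
            sum(x[i] * self.kernel[i][j] for i in range(len(x)))
            for j in range(self.features_out)
        ]


class MLP(Module):
    """Hypothetical two-layer model built from lazy layers."""

    def __init__(self) -> None:
        super().__init__()
        self.hidden = LazyLinear(4)
        self.out = LazyLinear(1)

    def __call__(self, x: List[float]) -> List[float]:
        return self.out(self.hidden(x))


x = [1.0, 2.0, 3.0]

forgot_inputs = MLP().init(key=0)
print(forgot_inputs.hidden.kernel)  # None: the kernel was never created

with_inputs = MLP().init(key=0, inputs=x)
print(len(with_inputs.hidden.kernel), len(with_inputs.hidden.kernel[0]))  # 3 4
```

In this sketch `init` runs the forward pass only when `inputs` is given, which is exactly the gap described above: calling `init` without inputs returns a module that reports itself as initialized while its layers still have no parameters.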