cgarciae opened this issue 2 years ago
cc: @anvelezec @ptigwe @Davidnet @samuela @lkhphuc @jmarrietar
I like the proposal and am excited about the upcoming changes. I think immutability, and therefore `apply(*, rng)`, is the right way forward.
JAX and its ecosystem are highly compositional by nature of functional programming, so inevitably I will need to use a tool provided by another JAX-based library.
Having trained models using Elegy for a while, I have created various bugs around the implicit-stateful interface, bugs which I think I would not make with the immutable approach.
> `Model` will now be a regular non-pytree Python object that would contain a `state: ElegyModule` field that it would maintain and update in place.

Should it be called a `Trainer` then, like PyTorch Lightning?
I don't use the current loss and metric modules as I mostly use the low-level API. A separate JAX-based library would ideally provide a functional API as well, so it could be used in the low-level API as well as in bare-bones JAX or other frameworks.
A bit related is first-class support for probabilistic programming in Elegy as well, i.e. Distrax. I have increasingly replaced my loss functions with calls like `p_x.log_prob(x)`, etc.
Currently it already mostly works thanks to the JAX-based functional design. However, there are still some sharp edges, like not being usable with the `.summary()` method, etc.
@lkhphuc After a bit of work, #70 is passing. This reworks all of Treex's abstractions to adopt an immutable API as proposed here. Here is an example using `apply`:
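The original snippet did not survive the copy, so here is a minimal, self-contained sketch of the immutable `apply` pattern (a toy stand-in, not the actual Treex code — the `Counter` class and its field are invented for illustration):

```python
# Toy illustration of the immutable `apply` pattern: calling `apply` returns
# the output *and* a new, updated copy of the module; the original is untouched.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Counter:
    n: int = 0

    def apply(self, key, x):
        # `key` would be a jax.random.PRNGKey in Treex; unused in this toy example
        new_self = replace(self, n=self.n + 1)  # state update captured in a copy
        return x * 2, new_self


module = Counter()
y, module2 = module.apply(None, 3)  # module is unchanged, module2 holds the new state
```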
This applies to the `Optimizer` as well:
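The optimizer snippet is also missing from the copy; the sketch below only illustrates the shape such an immutable optimizer could take (the `SGD` class and its `update` signature are assumptions, not the real Treex `Optimizer`):

```python
# Toy immutable optimizer: `update` returns new params AND a new optimizer copy
# carrying the updated internal state, instead of mutating itself in place.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SGD:
    lr: float
    step: int = 0

    def update(self, grads, params):
        new_params = {k: p - self.lr * grads[k] for k, p in params.items()}
        return new_params, replace(self, step=self.step + 1)


opt = SGD(lr=0.5)
params, opt = opt.update({"w": 1.0}, {"w": 2.0})
```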
It wasn't in this proposal but, when `train_step`, `test_step`, and friends are used as methods of a Module/Tree, it becomes a little cumbersome and error-prone to use `.replace` everywhere to update `self`. Instead I experimented with this `toplevel_mutable` decorator, which creates a copy of `self` and temporarily makes it mutable while keeping all subtrees immutable:
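The decorator's code did not survive the copy; here is a hypothetical re-implementation of the idea in plain Python (names and details assumed — a sketch of the pattern, not the actual Treex code):

```python
# Sketch of a `toplevel_mutable` decorator: copy `self`, lift immutability on
# the copy only, run the method, and return the updated copy.
import copy
import functools


class Immutable:
    def __setattr__(self, name, value):
        if not getattr(self, "_mutable", False):
            raise RuntimeError(f"Tree is immutable, cannot set '{name}'")
        object.__setattr__(self, name, value)


def toplevel_mutable(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        new_self = copy.copy(self)                      # work on a copy
        object.__setattr__(new_self, "_mutable", True)  # only the top level
        try:
            method(new_self, *args, **kwargs)
        finally:
            object.__setattr__(new_self, "_mutable", False)
        return new_self                                 # updated copy is returned
    return wrapper


class Model(Immutable):
    def __init__(self):
        object.__setattr__(self, "steps", 0)

    @toplevel_mutable
    def train_step(self, batch):
        self.steps = self.steps + 1   # plain attribute updates allowed here


model = Model()
model2 = model.train_step(batch=None)  # from the outside: model is unchanged
```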
From the outside this pattern looks really nice.
Similar to `mutable`, methods decorated with `toplevel_mutable` don't modify the original object, so the API is kept immutable from an outside perspective.
Hello, I just found this awesome library so my opinion is probably not very important, but here are my two cents:
> Now if your tree is immutable you would use `mutable`, which lets you run this method, but the updates are captured in a new instance that is returned along with the output of the method:
>
> ```python
> output, tree = tree.mutable(method="acc_sum")(x)
> ```
>
> Alternatively you could also use it as a function transformation via `treeo.mutable` like this:
>
> ```python
> output, tree = treeo.mutable(tree.acc_sum)(tree, x)
> ```
I think it makes more sense to call the method `mutate` instead of `mutable`. Generally, methods should be named after verbs. Please consider changing the name before you make a release!
> Should it be called a `Trainer` then, like PyTorch Lightning?
Yeah, I was confused for a moment by the name `elegy.Model`, since machine learning "models" typically aren't bundled with a loss, metrics, and an optimizer. `elegy.Trainer` sounds like a great name.
> The creation of an `ElegyModule` class (analogous to the `LightningModule`) that would centralize all the JAX-related parts of the training process. More specifically it would be a Pytree and would expose a framework-agnostic API; this means Treeo's Kind system would not be used now.
Could you elaborate on why Treeo's Kind system will no longer work? As you said, `ElegyModule` is just a Pytree, which I assume Treeo can work with seamlessly.
Hey @nalzok, thanks for taking the time to write this, opinions of any kind are welcome! This comment will also serve as an update on how the implementation evolved:
> I think it makes more sense to call the method mutate instead of mutable
Given the proposal also had an `apply` method, ultimately it was simpler to have a `mutable: bool` argument in `apply` which by default is `True`, so the previous example looks identical with `apply`.
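To make the trade-off concrete, here is a toy sketch of what an `apply` with a `mutable` flag could look like (the `CallCounter` class and the exact return convention are assumptions for illustration, not the real Treex API):

```python
# Toy `apply(..., mutable=True)`: the default returns (output, updated copy),
# mirroring the earlier `.mutable(...)` example; mutable=False returns output only.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CallCounter:
    calls: int = 0

    def apply(self, key, x, mutable=True):
        out = x + 1
        if mutable:
            return out, replace(self, calls=self.calls + 1)
        return out


m = CallCounter()
out, m2 = m.apply(None, 1)                 # default: mutable=True
out_only = m.apply(None, 1, mutable=False) # state updates discarded
```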
> Yeah, I was confused for a moment by the name elegy.Model since machine learning "models" typically aren't bundled with loss, metrics, and optimizer. elegy.Trainer sounds like a great name.
I too like the name `Trainer`; however, I am hesitant to make the change since it will break code that just uses the high-level API. Maybe we could rename it to `Trainer` and keep `Model` as an alias for backward compatibility.
> Could you elaborate on why Treeo's kind system will no longer work? As you said, ElegyModule is just a PyTree, which I assume treeo can work seamlessly with.
The thing is that Treeo Kinds are additional metadata added to the pytree leaves in order to create more powerful filters; this mirrored Flax's collections. While they simplified parts of the implementation a lot, users have to learn this additional framework. The solution is to have a regular pytree and have the user override a couple of additional methods (this can be automated for specific frameworks).
This is currently being implemented in poets-ai/elegy#232, here is an update on the resulting APIs:
API | Methods | Description
---|---|---
Core API | `train_step`, `test_step`, `pred_step`, `init_step` | User has full control, max flexibility; no logging or distributed strategies for free.
Managed API | `managed_train_step`, `managed_test_step`, `managed_pred_step`, `managed_init_step` | Similar to PyTorch Lightning, thus sufficiently flexible; gets logging and distributed strategies; has to define methods that specify how to get/set parameters and batch statistics.
High Level API | `init`, `apply` | User just specifies how to perform initialization and the forward pass; gets losses, metrics, distributed strategies, and logging for free; has to define methods that specify how to get/set parameters and batch statistics. Note: this API is mostly used to simplify the creation of framework-specific implementations (Flax, Haiku, etc.); not clear if it should be exposed to users.
> Given the proposal also had an `apply` method, ultimately it was simpler to have a `mutable: bool` argument in `apply` which by default is `True` so the previous example looks identical with `apply`.
Cool. The name `apply(mutable=...)` is also consistent with Flax's `Module.apply`, just without the `variables` parameter. This will make things easier for those who have some experience with Flax.
> Maybe we could rename it to `Trainer` and have `Model` as an alias for backward compatibility.
Yes, please. We can also emit a `DeprecationWarning` when the name `Model` is used, so that we can remove that name in a future major release.
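One way to sketch that alias (the class names and kwargs here are hypothetical, standing in for whatever Elegy's real constructor takes):

```python
# `Model` subclasses `Trainer` and warns on instantiation, so existing code
# keeps working while nudging users toward the new name.
import warnings


class Trainer:
    def __init__(self, module=None, **kwargs):
        self.module = module


class Model(Trainer):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "elegy.Model has been renamed to elegy.Trainer; "
            "the Model alias will be removed in a future major release.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    m = Model(module=object())
```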
> This is currently being implemented in https://github.com/poets-ai/elegy/pull/232, here is an update on the resulting APIs:
I see. Currently I don't quite understand how these work due to the lack of API documentation, but hopefully we can have some detailed documentation after things stabilize a little bit.
Regarding the documentation, do you think we should deprecate re-exporting the APIs of Treeo and Treex, or at least discourage users from using the re-exported APIs? Just like Keras doesn't re-export the API of TensorFlow, and users still need to import `tensorflow` when using Keras. More concretely, I am suggesting changing the example
```python
import jax
import optax

import elegy as eg


class MLP(eg.Module):
    @eg.compact
    def __call__(self, x):
        x = eg.Linear(300)(x)
        x = jax.nn.relu(x)
        x = eg.Linear(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        eg.losses.Crossentropy(),
        eg.regularizers.L2(l=1e-5),
    ],
    metrics=eg.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
to something like
```python
import jax
import jax_metrics as jm
import optax
import treeo as to
import treex as tx

import elegy as eg


class MLP(tx.Module):
    @to.compact
    def __call__(self, x):
        x = tx.Linear(300)(x)
        x = jax.nn.relu(x)
        x = tx.Linear(10)(x)
        return x


model = eg.Model(
    module=MLP(),
    loss=[
        jm.losses.Crossentropy(),
        jm.regularizers.L2(l=1e-5),
    ],
    metrics=jm.metrics.Accuracy(),
    optimizer=optax.rmsprop(1e-3),
)
```
This way, users will have an easier time finding the documentation for a lower-level API in the corresponding package, and we don't need to duplicate the same documentation in several places.
More importantly, it can help clarify the "level" of an API, e.g. `eg.Model` is higher-level compared to `tx.Module` and `jm.metrics.Accuracy` since it's the overarching trainer. While re-exporting the API from the dependencies can make things more convenient, because users don't need to remember which function comes from which package, I'm afraid it might cause conceptual confusion in the long run because users will adopt a flattened mental model `(elegy, treex, treeo, optax, jax_metrics)` instead of understanding the hierarchical structure of `(elegy(treex(treeo), optax, jax_metrics))`.
> Regarding the documentation, do you think we should deprecate re-exporting the API for Treeo and Treex, or at least discourage users from using the re-exported APIs?
Yes, definitely. If possible I want to make Treex an optional dependency; I want Elegy to embrace the "framework agnostic" slogan for real. Concretely, `elegy.Module` will not be `treex.Module`, so people will have to import `treex` or whatever framework they want to use.
The question for `jax_metrics` is interesting: should we re-export `losses`, `metrics`, and `regularizers`? I am inclined to say yes.
Ah yes, I think it's fine to re-export `jax_metrics` since the functions live in submodules, i.e. we have `jm.losses.Crossentropy()` instead of `jm.Crossentropy()`. (I would say `jm.losses.CrossEntropy()` is a better name though, otherwise the naming convention isn't really consistent.)
Here are some ideas for the Treeo, Treex, and Elegy libraries which hopefully add some quality-of-life improvements so they can stand the test of time a bit better.
## Immutability
Treeo/Treex has adopted a mutable/stateful design in favor of simplicity. While careful propagation of the mutated state inside jitted functions guarantees an overall immutable behaviour thanks to pytree cloning, there are some downsides:
### Proposal
Add an `Immutable` mixin in Treeo and have Treex use it for its base `Treex` class. This work already started in cgarciae/treeo#13 and will do the following:

* Override `__setattr__` to raise a `RuntimeError` when a field is being updated.
* Add a `replace(**kwargs) -> Tree` method that lets you replace the values of desired fields but returns a new object.
* Add a `mutable(method="__call__")(*args, **kwargs) -> (output, Tree)` method that lets you call another method that includes mutable operations in an immutable fashion.

Creating an immutable Tree via the `Immutable` mixin would look like this:
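The original snippet is missing from the copy; the following is a toy reconstruction of the mixin idea in plain Python (the class bodies are assumptions, not the real treeo code):

```python
# Minimal Immutable mixin: block __setattr__, provide replace() returning a copy.
import copy


class Immutable:
    def __setattr__(self, name, value):
        raise RuntimeError(f"Tree is immutable, cannot set field '{name}'")

    def replace(self, **kwargs):
        new = copy.copy(self)
        for name, value in kwargs.items():
            object.__setattr__(new, name, value)  # bypass immutability on the copy
        return new


class MyTree(Immutable):
    def __init__(self, n=0):
        object.__setattr__(self, "n", n)  # bypass immutability during init


tree = MyTree(n=1)
tree2 = tree.replace(n=5)  # new object; `tree` is untouched
```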
Additionally, Treeo could also expose an `ImmutableTree` class, so users who are not comfortable with mixins could do it like this:
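Again the snippet is lost; a hypothetical sketch of the base-class variant (the minimal `ImmutableTree` below is a stand-in, not the real class):

```python
# Subclassing an ImmutableTree base instead of mixing Immutable into a hierarchy.
import copy


class ImmutableTree:
    def __setattr__(self, name, value):
        raise RuntimeError("Tree is immutable")

    def replace(self, **kwargs):
        new = copy.copy(self)
        for name, value in kwargs.items():
            object.__setattr__(new, name, value)
        return new


class Params(ImmutableTree):
    def __init__(self, w=0.0):
        object.__setattr__(self, "w", w)


params = Params(w=1.0)
params2 = params.replace(w=2.0)
```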
### Examples

#### Field updates
Mutably you would update a field like this:
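The snippet is missing; a trivial stand-in (the `Tree` class and field name are invented for illustration):

```python
# Mutable style: update a field in place on a plain Python tree.
class Tree:
    def __init__(self):
        self.n = 0


tree = Tree()
tree.n = 5  # direct in-place update
```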
Whereas in the immutable version you use `replace` and get a new `tree`:
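Again the snippet is lost; a sketch using a frozen dataclass as a stand-in for a Tree (names assumed):

```python
# Immutable style: `replace` leaves the original untouched and returns a copy.
from dataclasses import dataclass, replace as dc_replace


@dataclass(frozen=True)
class Tree:
    n: int = 0

    def replace(self, **kwargs):
        return dc_replace(self, **kwargs)


tree = Tree(n=0)
tree2 = tree.replace(n=5)  # new object with the updated field
```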
#### Stateful Methods

Now if your Tree class had some stateful method such as:
Mutably you could simply use it like this:
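The class and its usage are missing from the copy; here is a toy stand-in mirroring the `acc_sum` method that is quoted earlier in the thread (the class body is an assumption):

```python
# A stateful method on a mutable tree: it updates internal state in place.
class Accumulator:
    def __init__(self):
        self.total = 0.0

    def acc_sum(self, x):
        self.total += x  # in-place state update
        return self.total


tree = Accumulator()
out = tree.acc_sum(3.0)  # mutable usage: just call the method
```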
Now if your tree is immutable you would use `mutable`, which lets you run this method, but the updates are captured in a new instance that is returned along with the output of the method. Alternatively you could also use it as a function transformation via `treeo.mutable`.
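Both call forms are quoted earlier in the thread (`tree.mutable(method="acc_sum")(x)` and `treeo.mutable(tree.acc_sum)(tree, x)`); here is a toy re-implementation of the transformation form so the semantics are concrete (an assumed sketch, not the real treeo code):

```python
# Toy `mutable` transformation: run a stateful method on a copy and return
# (output, updated_copy), leaving the original tree untouched.
import copy


class Accumulator:
    def __init__(self):
        self.total = 0.0

    def acc_sum(self, x):
        self.total += x
        return self.total


def mutable(method):
    def wrapper(tree, *args, **kwargs):
        new_tree = copy.copy(tree)                       # updates land on the copy
        output = method.__func__(new_tree, *args, **kwargs)
        return output, new_tree
    return wrapper


tree = Accumulator()
output, tree2 = mutable(tree.acc_sum)(tree, 3.0)  # mirrors treeo.mutable usage
```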
## Random State

Treex's `Module`s currently treat random state simply as internal state; because it's hidden, it's actually a bit more difficult to reason about and can cause a variety of issues such as:

### Proposal
Remove the `Rng` kind and create an `apply` method, similar to (but simpler than) Flax's `apply`, with the following signature:

As you can see, this method accepts an optional `key` as its first argument and then just the `*args` and `**kwargs` for the function. Regular usage would change from calling the module directly to passing the key through `apply`.
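The signature snippet and the before/after usage did not survive the copy; a hedged sketch of what they likely showed (all names here are assumptions, with a toy module standing in for a real Treex `Module`):

```python
# Sketch of the proposed `apply(key, *args, **kwargs) -> (output, module)`:
# random state is threaded explicitly instead of hidden inside the module.
from dataclasses import dataclass, replace
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class Module:
    counter: int = 0

    def __call__(self, x):
        return x + 1

    def apply(self, key: Optional[Any], *args, **kwargs) -> Tuple[Any, "Module"]:
        # `key` would be a jax.random.PRNGKey in practice; unused in this toy
        out = self(*args, **kwargs)
        return out, replace(self, counter=self.counter + 1)


module = Module()
# before: y = module(x)   (random state hidden inside the module)
# after:
y, module = module.apply(None, 41)
```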
However, if the module is stateless and doesn't require RNG state you can still call the module directly.
## Losses and Metrics

Current losses and metrics in Treex (which actually come from Elegy) are great! Since losses and metrics are mostly just Pytrees with simple state, it would be nice to extract them into their own library and, with some minor refactoring, build a framework-independent losses and metrics library that could be used by anyone in the JAX ecosystem. We could eventually create a library called `jax_tools` (or something) that contains utilities such as a `Loss` and `Metric` interface plus implementations of common losses and metrics, and maybe other utilities.

As for the Metric API, I was recently looking at `clu` from the Flax team and found some nice ideas that could make the implementation of distributed code simpler.
### Proposal

Make `Metric` immutable and update its API to the following. It is very similar to the Keras API, with the exception of the `aggregate` method, which is incredibly useful when syncing devices in a distributed setup.
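The API listing itself is missing from the copy; a hedged sketch of what an immutable Metric interface could look like (method names assumed, inspired by the Keras/clu comparison above — not the final API):

```python
# Immutable Metric sketch: `update` returns a new instance; `aggregate` merges
# per-device metric states before computing the final value.
from dataclasses import dataclass, replace
from typing import Sequence


@dataclass(frozen=True)
class Mean:
    total: float = 0.0
    count: int = 0

    def reset(self) -> "Mean":
        return Mean()

    def update(self, value: float) -> "Mean":
        return replace(self, total=self.total + value, count=self.count + 1)

    def compute(self) -> float:
        return self.total / max(self.count, 1)

    @staticmethod
    def aggregate(metrics: Sequence["Mean"]) -> "Mean":
        # in a distributed setup each device holds its own metric state;
        # aggregate sums the states so compute() gives the global value
        return Mean(
            total=sum(m.total for m in metrics),
            count=sum(m.count for m in metrics),
        )


m = Mean().update(1.0).update(3.0)
merged = Mean.aggregate([m, Mean().update(5.0)])
```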
## Elegy Model

Nothing concrete for the moment, but I'm thinking of a PyTorch Lightning-like architecture which would have the following properties:
* The creation of an `ElegyModule` class (analogous to the `LightningModule`) that would centralize all the JAX-related parts of the training process. More specifically, it would be a Pytree and would expose a framework-agnostic API; this means Treeo's Kind system would not be used now.
* `Model` will now be a regular non-pytree Python object that would contain a `state: ElegyModule` field that it would maintain and update in place.