apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Introduce @nowrap annotation. #797

Closed: ds-hwang closed this 5 days ago

ds-hwang commented 5 days ago

It's similar to Flax's @nowrap annotation: https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/decorators.html#flax.linen.nowrap

Marks the specified module method as one that doesn't need to be wrapped.

Methods decorated with @nowrap are helper methods that don't require wrapping, and _methods_to_wrap_for_auto_child_context() will not return them.

This is especially useful in cases where a public method (i.e., one that is not explicitly prefixed with _) does not need an invocation context, such as methods that do not attempt to access state or PRNG keys.

For instance:

        >>> import jax.numpy as jnp
        >>> from axlearn.common import module
        >>> class Foo(module.Module):
        ...   @module.nowrap
        ...   def init_states(self, batch_size: int):
        ...     return dict(time_step=jnp.zeros(batch_size, dtype=jnp.int32))
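To illustrate the mechanism described above, here is a minimal, self-contained sketch of how a decorator like `@nowrap` could mark a method and how a filter in the spirit of `_methods_to_wrap_for_auto_child_context()` could skip it. This is not the AXLearn implementation: the `_nowrap` marker attribute, the toy `Module` base class, and the filtering logic are assumptions made for illustration, and JAX is left out so the sketch runs on its own.

```python
import inspect
from typing import Callable, List, TypeVar

F = TypeVar("F", bound=Callable)

# Hypothetical marker attribute; the real AXLearn implementation may differ.
_NOWRAP_ATTR = "_nowrap"


def nowrap(fn: F) -> F:
    """Marks `fn` as a helper that should not be wrapped in an invocation context."""
    setattr(fn, _NOWRAP_ATTR, True)
    return fn


class Module:
    """Toy stand-in for a module base class (illustration only, not axlearn)."""

    def _methods_to_wrap_for_auto_child_context(self) -> List[str]:
        # Collect public methods (names without a leading "_"),
        # skipping any method that was marked with @nowrap.
        names = []
        for name, fn in inspect.getmembers(type(self), predicate=inspect.isfunction):
            if name.startswith("_"):
                continue
            if getattr(fn, _NOWRAP_ATTR, False):
                continue
            names.append(name)
        return names


class Foo(Module):
    def forward(self, x):
        # Public and unmarked: would be wrapped with an invocation context.
        return x + 1

    @nowrap
    def init_states(self, batch_size: int):
        # Pure helper: touches no state or PRNG keys, so no context is needed.
        return dict(time_step=[0] * batch_size)


print(Foo()._methods_to_wrap_for_auto_child_context())  # ['forward']
```

Marking the function object itself (rather than keeping a separate registry) means the decision travels with the method through subclassing, and the filter only needs a cheap attribute check.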
ds-hwang commented 5 days ago

Could you take a look? from #861