ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

Nested mutations using .at['method_key'] #80

Closed: ASEM000 closed this 1 year ago

ASEM000 commented 1 year ago

Motivation:

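Experiment with lazy initialization, where inner modules create their parameters on the first call and infer their shapes from the input (e.g., `in_features` from `x.shape[-1]`). This requires mutating nested modules from inside a method, which this PR enables through `.at['method_key']`: the method runs with nested modules allowed to set attributes, and the call returns the method output together with the updated tree.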
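The get-or-set pattern behind lazy initialization can be shown without JAX or pytreeclass. The minimal sketch below uses a hypothetical plain-Python class `LazyModule` that mirrors the `param` helper in the example: the first call stores the value, and every later call returns the stored value and ignores its argument.

```python
from typing import Any


class LazyModule:
    """Hypothetical plain-Python stand-in for a lazily initialized module."""

    def param(self, name: str, value: Any) -> Any:
        # First call: store `value` under `name`.
        # Later calls: the stored value wins and `value` is ignored.
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]


m = LazyModule()
first = m.param("in_features", 5)   # stored
second = m.param("in_features", 7)  # ignored, stored value returned
print(first, second)  # 5 5
```

One consequence of this design is that Python evaluates `value` eagerly, so an expression such as `jnp.ones(...)` is still built on every call even after the parameter exists. Passing a zero-argument callable instead would avoid that; this is a possible variation, not something the PR proposes.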

Example

import pytreeclass as pytc
import jax.random as jr
from typing import Any
import jax
import jax.numpy as jnp

@pytc.autoinit
class LazyLinear(pytc.TreeClass):
    out_features: int

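    def param(self, name: str, value: Any):
        # Get-or-set: store `value` on the first call, then always return the stored value.
        if name not in vars(self):
            # TreeClass is immutable outside __init__; this works because the call runs via `.at[...]`.
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):
        # `key` is unused here: parameters are initialized deterministically (ones/zeros).
        in_features = self.param("in_features", x.shape[-1])  # inferred from the input
        weight = self.param("weight", jnp.ones((in_features, self.out_features)))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight + bias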

@pytc.autoinit
class StackedLinear(pytc.TreeClass):
    l1: LazyLinear = LazyLinear(10)
    l2: LazyLinear = LazyLinear(10)

    def call(self, x: jax.Array):
        # When invoked as `.at["call"](x)`, l1 and l2 may set their own attributes (nested mutation).
        return self.l2(self.l1(x))

l = StackedLinear()
print(repr(l))
# StackedLinear(l1=LazyLinear(out_features=10), l2=LazyLinear(out_features=10))

# `.at["call"]` returns (method output, updated copy); `l` itself is left unchanged
_, ll = l.at["call"](jnp.ones((1, 5)))
print(repr(ll))
# StackedLinear(
#   l1=LazyLinear(
#     out_features=10, 
#     in_features=5, 
#     weight=f32[5,10](μ=1.00, σ=0.00, ∈[1.00,1.00]), 
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ), 
#   l2=LazyLinear(
#     out_features=10, 
#     in_features=10, 
#     weight=f32[10,10](μ=1.00, σ=0.00, ∈[1.00,1.00]), 
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )
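To make the semantics of `l.at["call"](...)` concrete, here is a plain-Python sketch of the behavior the example relies on, assuming (as the `_, ll = ...` unpacking suggests) that calling a method through `.at` returns the method output together with an updated tree while the original stays untouched. The helper `at_method` and the list-based `LazyLinear`/`StackedLinear` stand-ins are hypothetical; this is not how pytreeclass implements `.at`, and it uses a 1-D list in place of the `(1, 5)` array.

```python
import copy
from typing import Any, List


class LazyLinear:
    """Hypothetical stand-in: infers in_features from the first input."""

    def __init__(self, out_features: int):
        self.out_features = out_features

    def param(self, name: str, value: Any) -> Any:
        # Same get-or-set logic as LazyLinear.param in the example.
        if name not in vars(self):
            setattr(self, name, value)
        return vars(self)[name]

    def __call__(self, x: List[float]) -> List[float]:
        in_features = self.param("in_features", len(x))
        # weight is all ones, bias all zeros, as in the JAX example.
        weight = self.param("weight", [[1.0] * self.out_features for _ in range(in_features)])
        bias = self.param("bias", [0.0] * self.out_features)
        return [
            sum(x[i] * weight[i][j] for i in range(in_features)) + bias[j]
            for j in range(self.out_features)
        ]


class StackedLinear:
    def __init__(self):
        self.l1 = LazyLinear(10)
        self.l2 = LazyLinear(10)

    def call(self, x: List[float]) -> List[float]:
        return self.l2(self.l1(x))


def at_method(tree: Any, name: str, *args: Any, **kwargs: Any):
    """Hypothetical helper: run method `name` on a deep copy, return (output, copy)."""
    new_tree = copy.deepcopy(tree)
    out = getattr(new_tree, name)(*args, **kwargs)
    return out, new_tree


s = StackedLinear()
y, s_init = at_method(s, "call", [1.0] * 5)
print(s_init.l1.in_features, s_init.l2.in_features)  # 5 10
print("in_features" in vars(s.l1))  # False: the original is untouched
```

Deep-copying the whole tree is the simplest way to get copy-then-mutate semantics in plain Python; it ignores the cost of copying large array leaves, which a real implementation would likely want to avoid.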
codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (23315f6) 98.92% compared to head (8c07c79) 98.93%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main      #80   +/-   ##
=======================================
  Coverage   98.92%   98.93%
=======================================
  Files          15       15
  Lines        2697     2715   +18
=======================================
+ Hits         2668     2686   +18
  Misses         29       29
```

| Files Changed | Coverage Δ | |
|---|---|---|
| `pytreeclass/_src/tree_base.py` | `100.00% <100.00%> (ø)` | |
| `tests/test_treeclass.py` | `98.96% <100.00%> (+0.05%)` | :arrow_up: |
