ASEM000 / pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.
https://pytreeclass.rtfd.io/en/latest
Apache License 2.0

use instance variable #51

Closed ASEM000 closed 1 year ago

ASEM000 commented 1 year ago

This PR enables adding a Field to instance variables of treeclass-wrapped classes after initialization, using the .at functionality. For example, in torch we can do the following:

import torch 

class Tree(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.tensor(bias))

tree = Tree(0.)
tree.weight = torch.nn.Parameter(torch.tensor(3.))
list(tree.parameters())
# [Parameter containing:
#  tensor(0., requires_grad=True),
#  Parameter containing:
#  tensor(3., requires_grad=True)]

Torch overrides __setattr__ to recognize values of type Parameter and Module and register them with the model. The example shows that a parameter not defined in the __init__ method can be added after initialization.
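The registration idea can be sketched in plain Python without torch. This is only an illustration of the __setattr__ pattern: the Parameter and Module classes below are hypothetical stand-ins, not torch's actual implementation, which does considerably more.

```python
class Parameter:
    """Hypothetical stand-in for torch.nn.Parameter."""

    def __init__(self, value):
        self.value = value


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        # Create the registry directly, bypassing our own __setattr__.
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        # Record Parameter values in the registry, then store the attribute as usual.
        if isinstance(value, Parameter):
            self._parameters[name] = value
        object.__setattr__(self, name, value)

    def parameters(self):
        # Parameters come back in the order they were assigned.
        return list(self._parameters.values())


class Tree(Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = Parameter(bias)


tree = Tree(0.0)
tree.weight = Parameter(3.0)  # assigned after __init__, still registered
values = [p.value for p in tree.parameters()]
print(values)
# [0.0, 3.0]
```

Because registration happens at assignment time, the instance, not the class definition, decides which parameters exist.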

Since PyTreeClass-wrapped classes are immutable by default, setting an attribute after initialization is only possible using the .at functionality, which creates a new instance with the updated value. However, in pytreeclass, fields (i.e., class parameters) are tied to the class, not the instance. This means that no field can be added after initialization through the .at method. This PR adds this functionality, which is demonstrated in the following example:


import pytreeclass as pytc
from typing import Any

@pytc.treeclass
class Parameter:
    value: Any

@pytc.treeclass
class Tree:
    bias: int = 0

    def add_param(self, name, param):
        return setattr(self, name, param)

tree = Tree()

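# Calling a method through .at returns (method return value, new instance);
# the original tree is left unchanged.
_, tree_with_weight_param = tree.at['add_param']('weight', Parameter(3))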

print(tree)
# Tree(bias=0)

print(tree_with_weight_param)
# Tree(bias=0, weight=Parameter(value=3))
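
The copy-then-mutate behavior shown above can be approximated in plain Python. The sketch below is not PyTreeClass's implementation: the Frozen base class and the call_on_copy helper are hypothetical names chosen for illustration, with call_on_copy playing the role of .at['method'](...).

```python
import copy


class Frozen:
    """Hypothetical immutable base class (illustration only)."""

    _mutable = False  # class-level default: instances are locked

    def __setattr__(self, name, value):
        if not self._mutable:
            raise AttributeError(f"cannot set {name!r} on an immutable instance")
        object.__setattr__(self, name, value)

    def call_on_copy(self, method_name, *args, **kwargs):
        # Hypothetical helper: run a mutating method on a shallow copy.
        new = copy.copy(self)
        object.__setattr__(new, "_mutable", True)  # unlock the copy only
        try:
            result = getattr(new, method_name)(*args, **kwargs)
        finally:
            # Drop the instance-level flag so the class default (locked) applies again.
            object.__delattr__(new, "_mutable")
        return result, new


class Tree(Frozen):
    def __init__(self, bias=0):
        object.__setattr__(self, "bias", bias)

    def add_param(self, name, param):
        setattr(self, name, param)


tree = Tree()
_, tree_with_weight = tree.call_on_copy("add_param", "weight", 3)

print(vars(tree))
# {'bias': 0}
print(vars(tree_with_weight))
# {'bias': 0, 'weight': 3}
```

Here the new attribute lives only on the returned copy, which mirrors how the instance-level field appears on tree_with_weight_param but not on tree.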