Closed: thomaspinder closed this issue 2 years ago.
Hey @thomaspinder! Thanks for the kind words.
Let me first rule out the easy solution; I'm curious whether this is enough:
```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    pass


def transform(x):
    return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()
jax.tree_map(transform, to.filter(m, Parameter))
```
One thing to notice is that Kinds are just types that serve as metadata linked to a field; they are not expected to be instantiated.
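To make the idea of a kind concrete, here is a minimal, library-independent sketch using only the standard `dataclasses` module. It is an analogy, not treeo's implementation: the kind is a bare class stored in field metadata, selection compares the class by identity, and the class itself is never instantiated. The helper `fields_of_kind` and the `"kind"` metadata key are hypothetical names chosen for illustration.

```python
import math
from dataclasses import dataclass, field, fields


class Parameter:
    """A kind: a bare class used purely as a tag; it is never instantiated."""


@dataclass
class Model:
    # The kind is stored in the field's metadata (an analogy for treeo's `kind=`)
    lengthscale: float = field(default=1.0, metadata={"kind": Parameter})
    # A field with no kind attached
    name: str = "rbf"


def fields_of_kind(obj, kind):
    """Return {field name: value} for the fields whose metadata kind is `kind`."""
    return {
        f.name: getattr(obj, f.name)
        for f in fields(obj)
        if f.metadata.get("kind") is kind  # compares the class itself; no instance is created
    }


params = fields_of_kind(Model(), Parameter)
transformed = {name: math.exp(value) for name, value in params.items()}
```

Only `lengthscale` is selected, because `name` carries no kind in its metadata.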
Hey @cgarciae, thanks, but perhaps my MWE was an oversimplification. The reason for defining the transform as a method of the kind class is that there can be numerous kind classes, e.g.:
```python
import jax.numpy as jnp


class PositiveParameter:
    def transform(self, x):
        return jnp.abs(x)


class NegativeParameter:
    def transform(self, x):
        return jnp.array(-1.0) * x
```

and so on...
This makes the solution you've proposed a little trickier, as there would need to be some awkward mapping from each kind to its function.
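A minimal plain-Python sketch of that trade-off, with hypothetical names and no treeo involved: with free functions, every kind needs an entry in a separate registry that has to be kept in sync with the kind classes by hand, whereas putting `transform` on the kind lets the code dispatch directly through the class.

```python
from typing import Callable, Dict


class PositiveParameter:
    def transform(self, x: float) -> float:
        return abs(x)


class NegativeParameter:
    def transform(self, x: float) -> float:
        return -1.0 * x


# Free-function approach: a separate, hand-maintained (hypothetical) registry
# from each kind to its function; every new kind needs a new entry here.
TRANSFORMS: Dict[type, Callable[[float], float]] = {
    PositiveParameter: abs,
    NegativeParameter: lambda x: -1.0 * x,
}


def via_registry(kind: type, x: float) -> float:
    return TRANSFORMS[kind](x)


def via_method(kind: type, x: float) -> float:
    # Method approach: the kind carries its own transform, so no registry is needed.
    return kind().transform(x)
```

Both functions give the same results; the method version simply removes the registry that must be updated whenever a kind is added.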
Based on the tidy solution you provided in #3, one possible solution to this problem could be the following. Do you see any issues with it?
```python
from dataclasses import dataclass
from typing import Set

import jax.numpy as jnp
import treeo as to


class KindOne:
    def transform(self):
        def transform_fn(x):
            return jnp.abs(x)

        return transform_fn


class KindTwo:
    def transform(self):
        def transform_fn(x):
            return jnp.array(-1.0) * x

        return transform_fn


@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)


@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)


def unique_kinds(tree: to.Tree) -> Set[type]:
    """Collect every kind attached to a field anywhere in `tree`."""
    kinds = set()

    def add_subtree_kinds(subtree: to.Tree):
        for meta in subtree.field_metadata.values():
            if meta.kind is not type(None):  # skip fields without a kind
                kinds.add(meta.kind)

    to.apply(add_subtree_kinds, tree)
    return kinds


sub_m = SubModel()
m = Model(submodel=sub_m)

# Apply each kind's transform only to the fields of that kind
for kind in unique_kinds(m):
    transform = kind().transform()
    m = to.map(transform, m, kind)
```
@thomaspinder I was guessing you were trying to do this 😅
Here is the solution:
```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()

# Inside this context, the leaves seen by tree_map are FieldInfo objects
with to.add_field_info():
    m2 = jax.tree_map(lambda field: field.kind.transform(field.value), to.filter(m, Parameter))

print(m2)
```
The `add_field_info` function probably needs a section in the User Guide. What it does is that, when a Tree is flattened, its leaves will all be of a type called `FieldInfo`, which among other things contains the `kind` and `value` attributes that you can use to achieve what you want. Note that I've changed `transform` to be a `staticmethod`.
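The mechanism can be mimicked in plain Python to show why the lambda works. In this hypothetical sketch (not treeo's code), `flatten_with_info` plays the role of flattening inside `to.add_field_info()`: each leaf is wrapped in a `FieldInfo`-like record carrying the field's `name`, `value` and `kind`, so the mapped function can dispatch on the kind. Only the `kind`/`value` attributes mirror the description above; the record layout and helper names are assumptions.

```python
import math
from dataclasses import dataclass, field, fields
from typing import Any, List, NamedTuple, Optional


class Parameter:
    @staticmethod
    def transform(x: float) -> float:
        return math.exp(x)


class FieldInfo(NamedTuple):
    """Hypothetical stand-in for a leaf that carries its field metadata with the value."""

    name: str
    value: Any
    kind: Optional[type]


@dataclass
class Model:
    lengthscale: float = field(default=0.0, metadata={"kind": Parameter})
    variance: float = field(default=1.0, metadata={"kind": Parameter})


def flatten_with_info(obj) -> List[FieldInfo]:
    """Toy 'flatten': every leaf becomes a FieldInfo instead of a bare value."""
    return [FieldInfo(f.name, getattr(obj, f.name), f.metadata.get("kind")) for f in fields(obj)]


def transform_leaf(info: FieldInfo) -> float:
    # Same shape as the lambda in the treeo example: dispatch on the kind, transform the value.
    return info.kind.transform(info.value)


m2 = Model(**{info.name: transform_leaf(info) for info in flatten_with_info(Model())})
```

Because `transform` is a `staticmethod`, `info.kind.transform(...)` is called on the class directly and no kind instance is ever created.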
My thought is that if this pattern becomes more widespread, it would be convenient to add an `add_field_info: bool` argument to `to.map` so you could write something like this:
```python
m2 = to.map(
    lambda field: field.kind.transform(field.value),
    to.filter(m, Parameter),
    add_field_info=True,
)
```
BTW: not sure if this is relevant to you, but if you are doing something like this:
```python
params = jax.tree_map(some_function, to.filter(m, Parameter))
m = to.merge(m, params)
```
You can simply use:
```python
m = to.map(some_function, m, Parameter)
```
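The equivalence can be illustrated with a toy model in plain Python, where a "tree" is a flat dict of `(kind, value)` leaves. `filter_kind`, `tree_map`, `merge` and `map_kind` are hypothetical helpers that only imitate the roles of `to.filter`, `jax.tree_map`, `to.merge` and `to.map` described above; they are not treeo's implementations.

```python
from typing import Any, Callable, Dict, Tuple

Leaf = Tuple[type, Any]  # (kind, value)
Tree = Dict[str, Leaf]


class Parameter: ...


class State: ...


def filter_kind(tree: Tree, kind: type) -> Tree:
    """Keep only the leaves of the given kind; other leaves are dropped."""
    return {name: leaf for name, leaf in tree.items() if leaf[0] is kind}


def tree_map(f: Callable[[Any], Any], tree: Tree) -> Tree:
    """Apply f to every value, preserving the kind tags."""
    return {name: (kind, f(value)) for name, (kind, value) in tree.items()}


def merge(tree: Tree, update: Tree) -> Tree:
    """Overwrite the leaves of `tree` with those present in `update`."""
    return {**tree, **update}


def map_kind(f: Callable[[Any], Any], tree: Tree, kind: type) -> Tree:
    """One-step version: transform leaves of `kind`, leave the rest untouched."""
    return {name: (k, f(v)) if k is kind else (k, v) for name, (k, v) in tree.items()}


m: Tree = {"w": (Parameter, 2.0), "count": (State, 5)}
two_step = merge(m, tree_map(lambda x: x * 10, filter_kind(m, Parameter)))
one_step = map_kind(lambda x: x * 10, m, Parameter)
```

Both routes produce the same tree; the one-step form just avoids building and merging the intermediate filtered tree.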
@thomaspinder sure! That solution based on #3 works. For ergonomics, you can convert `transform` to a `staticmethod` so you don't have to instantiate the kind.
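A minimal plain-Python sketch of that suggestion (hypothetical names, and a list of `(kind, value)` pairs standing in for a treeo Tree): with `transform` as a `staticmethod`, the loop over unique kinds from the #3-based solution can call `kind.transform` directly instead of `kind().transform()`.

```python
class KindOne:
    @staticmethod
    def transform(x: float) -> float:
        return abs(x)


class KindTwo:
    @staticmethod
    def transform(x: float) -> float:
        return -1.0 * x


# Toy stand-in for a model: (kind, value) pairs instead of a treeo Tree
leaves = [(KindOne, -3.0), (KindTwo, 2.0), (KindOne, 4.0)]

# Collect the distinct kinds, then apply each kind's transform to its own leaves
for kind in {k for k, _ in leaves}:
    # `kind.transform` is used directly: no `kind()` instance is created
    leaves = [(k, kind.transform(v)) if k is kind else (k, v) for k, v in leaves]
```

The iteration order over the set of kinds does not matter here, since each kind only touches its own leaves.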
Thanks so much. The solution you give using `with to.add_field_info()...` is the perfect solution to my problem. Adding the additional argument to `to.map()` would be really great; if you ever want a hand with this, e.g. writing tests/documentation, I'd be happy to help you out.
@thomaspinder happy to guide you if you want to contribute 🌝 This issue looks very self-contained and could be a good starting point.
Ping me if you need anything.
Sure! I'd be happy to contribute. Are you able to outline the main steps that I should be mindful of when doing this?
I think adding a `field_info: bool` argument to `map` and then conditionally using the `add_field_info` context manager over this line should be enough:
https://github.com/cgarciae/treeo/blob/master/treeo/api.py#L197
Also, try to add a test 😃. Sadly we don't have a contributing document yet, but to start developing do the following:
- Install `poetry`.
- `poetry install` to install dependencies.
- `poetry shell` to activate the environment.
- `pre-commit install` to install the pre-commit hooks.
- `pytest` to run tests.
Firstly, thanks for creating Treeo; it's a fantastic package.
Is there a way to use methods defined within a field's `kind` object within a `tree_map` call? For example, consider the following MWE. Is there a way that I could do something similar to the following pseudocode snippet: