cgarciae / treeo

A small library for creating and manipulating custom JAX Pytree classes
https://cgarciae.github.io/treeo
MIT License

Use field kinds within tree_map #2

Closed (thomaspinder closed this issue 2 years ago)

thomaspinder commented 3 years ago

Firstly, thanks for creating Treeo - it's a fantastic package.

Is there a way to use methods defined on a field's kind class within a tree_map call? For example, consider the following MWE:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

class Parameter:
    def transform(self):
        return jnp.exp(self)

@dataclass
class Model(to.Tree):
    lengthscale: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

Is there a way to do something similar to the following pseudocode snippet?

m = Model()
jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))

cgarciae commented 3 years ago

Hey @thomaspinder! Thanks for the kind words.

Let me first rule out the easy solution. Is this enough?

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

class Parameter:
    pass

def transform(x):
    return jnp.exp(x)

@dataclass
class Model(to.Tree):
    lengthscale: jnp.array = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

m = Model()
jax.tree_map(transform, to.filter(m, Parameter))

One thing to notice is that kinds are just types that serve as metadata linked to a field; they are not expected to be instantiated.

thomaspinder commented 3 years ago

Hey @cgarciae, thanks, but perhaps my MWE was an oversimplification. The reason for defining the transform as a method of the kind class is that there can be numerous kind classes, each with its own transform, e.g.,

class PositiveParameter():
    def transform(self):
        return jnp.abs(self)

class NegativeParameter():
    def transform(self):
        return jnp.array(-1.) * self 

and so on...

This makes the solution you've proposed a little more tricky as there'd need to be some awkward function mappings.
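
For illustration, the kind of mapping this would require might look like the sketch below. It is plain Python, with abs and negation standing in for the jnp calls; TRANSFORMS and apply_transform are hypothetical names, not part of Treeo.

```python
class PositiveParameter:
    pass


class NegativeParameter:
    pass


# Hypothetical lookup table, kept separately from the kind classes.
# Every new kind needs a matching entry here, which is the awkward part.
TRANSFORMS = {
    PositiveParameter: abs,
    NegativeParameter: lambda x: -1.0 * x,
}


def apply_transform(kind, value):
    """Look up the transform registered for `kind` and apply it to `value`."""
    return TRANSFORMS[kind](value)


positive = apply_transform(PositiveParameter, -2.0)
negative = apply_transform(NegativeParameter, 3.0)
```

The table and the kind definitions can drift apart, which is why keeping the transform on the kind itself is more appealing.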

thomaspinder commented 3 years ago

Based on the tidy solution you've provided in #3, one possible solution to this problem could be the following. Do you see any issues with this?

from dataclasses import dataclass
from typing import Set

import jax
import jax.numpy as jnp
import treeo as to
from treeo.utils import field

class KindOne:
    def transform(self):
        def transform_fn(x):
            return jnp.abs(x)
        return transform_fn

class KindTwo: 
    def transform(self):
        def transform_fn(x):
            return jnp.array(-1.) * x
        return transform_fn

@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)

@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)

def unique_kinds(tree: to.Tree) -> Set[type]:
    kinds = set()

    def add_subtree_kinds(subtree: to.Tree):
        for field in subtree.field_metadata.values():
            if field.kind is not type(None):
                kinds.add(field.kind)

    to.apply(add_subtree_kinds, tree)

    return kinds

sub_m = SubModel()
m = Model(submodel=sub_m)

for kind in unique_kinds(m):
    transform = kind().transform()
    m = to.map(transform, m, kind)

cgarciae commented 3 years ago

@thomaspinder I was guessing you were trying to do this 😅

Here is the solution:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to
from treeo.utils import field

class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)

@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

m = Model()

with to.add_field_info():
    m2 = jax.tree_map(lambda field: field.kind.transform(field.value), to.filter(m, Parameter))

print(m2)

The add_field_info function probably needs a section in the User Guide. While it is active, flattening a Tree produces leaves of a type called FieldInfo, which among other things contains the kind and value attributes that you can use to achieve what you want. Note that I've changed transform to be a staticmethod.
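
As a rough illustration of the idea only, the sketch below uses plain Python with a hypothetical stand-in FieldInfo class (not Treeo's actual type) and math in place of jnp. It shows how a leaf that carries both value and kind lets each leaf pick its own transform.

```python
import math
from dataclasses import dataclass


@dataclass
class FieldInfo:
    """Hypothetical stand-in for Treeo's FieldInfo: a leaf value plus its kind."""

    value: float
    kind: type


class Exponential:
    @staticmethod
    def transform(x):
        return math.exp(x)


class Negated:
    @staticmethod
    def transform(x):
        return -x


# Leaves as they might look once each value is wrapped with its metadata.
leaves = [FieldInfo(0.0, Exponential), FieldInfo(2.0, Negated)]

# Each leaf selects its transform through its kind. Because transform is a
# staticmethod, no instance of the kind class is ever created.
transformed = [leaf.kind.transform(leaf.value) for leaf in leaves]
```

With these inputs, transformed holds exp(0.0) for the first leaf and the negated value for the second.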

My thinking is that if this pattern becomes more widespread, it would be convenient to add an add_field_info: bool argument to to.map so you could write something like this:

m2 = to.map(
    lambda field: field.kind.transform(field.value), 
    to.filter(m, Parameter), 
    add_field_info=True,
)

cgarciae commented 3 years ago

BTW: Not sure if this is relevant to you but if you are doing something like this:

params = jax.tree_map(some_function, to.filter(m, Parameter))
m = to.merge(m, params)

You can simply use:

m = to.map(some_function, m, Parameter)
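
To make the equivalence concrete, here is a toy model in plain Python, not Treeo itself. Trees are dicts, kinds live in a side table, and filter_kind, merge, and map_kind are hypothetical stand-ins for to.filter, to.merge, and to.map; the real functions operate on Pytrees and may differ in detail.

```python
class Parameter:
    pass


class Other:
    pass


def filter_kind(tree, kinds, kind):
    """Keep only the entries whose kind matches."""
    return {name: value for name, value in tree.items() if kinds[name] is kind}


def merge(tree, update):
    """Overlay the updated entries on the original tree."""
    return {**tree, **update}


def map_kind(fn, tree, kinds, kind):
    """One-step version: apply fn to matching entries, leave the rest untouched."""
    return {
        name: fn(value) if kinds[name] is kind else value
        for name, value in tree.items()
    }


def double(x):
    return 2 * x


tree = {"lengthscale": 1.0, "offset": 5.0}
kinds = {"lengthscale": Parameter, "offset": Other}

# Two-step version: filter, map, then merge back.
params = {name: double(v) for name, v in filter_kind(tree, kinds, Parameter).items()}
two_step = merge(tree, params)

# One-step version.
one_step = map_kind(double, tree, kinds, Parameter)
```

In this toy, both routes produce the same result: only the Parameter entry is transformed.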

cgarciae commented 3 years ago

@thomaspinder sure! That solution based on #3 works. For ergonomics you can convert transform to be a staticmethod so you don't have to instantiate the kind.

thomaspinder commented 3 years ago

Thanks so much. The solution you give using with to.add_field_info() is perfect for my problem. Adding the additional argument to to.map() would be really great. If you ever want a hand with this (e.g., writing tests or documentation), I'd be happy to help out.

cgarciae commented 3 years ago

@thomaspinder happy to guide you if you want to contribute 🌝 This issue looks very self-contained, so it could be a good starting point.

Ping me if you need anything.

thomaspinder commented 2 years ago

Sure! I'd be happy to contribute. Are you able to outline the main steps that I should be mindful of when doing this?

cgarciae commented 2 years ago

I think adding a field_info: bool argument to map and then conditionally using the add_field_info context manager over this line should be enough:

https://github.com/cgarciae/treeo/blob/master/treeo/api.py#L197
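
A minimal sketch of the "conditionally use the context manager" pattern, assuming nothing about Treeo's internals: add_field_info, _leaf_view, and map_leaves below are toy stand-ins written for this example, and contextlib.nullcontext supplies the no-op branch.

```python
from contextlib import contextmanager, nullcontext

_field_info_enabled = False  # toy global state; the real mechanism may differ


@contextmanager
def add_field_info():
    """Hypothetical stand-in for to.add_field_info: flips a flag while active."""
    global _field_info_enabled
    previous = _field_info_enabled
    _field_info_enabled = True
    try:
        yield
    finally:
        _field_info_enabled = previous


def _leaf_view(kind, value):
    # Expose metadata only while the flag is set; otherwise return the bare value.
    return {"kind": kind, "value": value} if _field_info_enabled else value


def map_leaves(fn, leaves, field_info=False):
    """Toy map over (kind, value) pairs with an opt-in field_info flag."""
    # nullcontext() does nothing, so the default path behaves as before.
    ctx = add_field_info() if field_info else nullcontext()
    with ctx:
        return [fn(_leaf_view(kind, value)) for kind, value in leaves]


leaves = [("Parameter", 1.0)]
plain = map_leaves(lambda x: x * 2, leaves)
with_info = map_leaves(lambda f: (f["kind"], f["value"]), leaves, field_info=True)
```

Defaulting the flag to False keeps existing callers unaffected, which also makes the change straightforward to cover with a test for each branch.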

Also try to add a test 😃. Sadly we don't have a contributing document yet, but to start developing do the following:

  1. Install poetry
  2. Run poetry install to install dependencies
  3. Run poetry shell to activate the environment.
  4. Run pre-commit install to install the pre-commit hooks.
  5. Run pytest to run tests.