cgarciae / treeo

A small library for creating and manipulating custom JAX Pytree classes
https://cgarciae.github.io/treeo
MIT License

Jitting twice for a class method #22

Closed. pipme closed this issue 2 years ago.

pipme commented 2 years ago
import jax
import jax.numpy as jnp
import treeo as to

class A(to.Tree):
    X: jnp.ndarray = to.field(node=True)

    def __init__(self):
        self.X = jnp.ones((50, 50))

    @jax.jit
    def f(self, Y):
        return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)

Y = jnp.ones(2)
for i in range(5):
    print(A.f._cache_size())  # number of entries in the jit cache of A.f (private JAX helper)
    a = A()
    a.f(Y)

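With JAX 0.3.15 the loop above prints 0 1 2 2 2: the jit cache of A.f grows to two entries even though every call passes arguments with the same shapes and dtypes. With JAX 0.3.10 it prints 0 1 1 1 1, as expected. I have no idea what's happening. Thanks.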
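One way to check whether the jump from 1 to 2 in the cache size corresponds to an actual recompilation is to turn on JAX's compile logging. This is a minimal sketch, assuming the jax_log_compiles config flag is available in the installed JAX version; to avoid depending on treeo, the function below is a stand-in with the same body as A.f, with self.X passed as an explicit argument.

```python
import logging

import jax
import jax.numpy as jnp

# JAX reports compilations through the standard logging module when this flag is on.
logging.basicConfig(level=logging.WARNING)
jax.config.update("jax_log_compiles", True)

X = jnp.ones((50, 50))


@jax.jit
def f(X, Y):
    # Same computation as A.f, with self.X replaced by an explicit argument.
    return jnp.sum(Y ** 2) * jnp.sum(X ** 2)


Y = jnp.ones(2)
for i in range(5):
    # Any compilation log line emitted during a call (wording varies by JAX
    # version) means that call compiled f again instead of reusing the cache.
    f(X, Y)
```

If only the first call logs a compilation, the second cache entry does not come from a second XLA compilation.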

pipme commented 2 years ago

It looks like the ** 2 is involved: the following version prints 0 1 1 1 1. Still not sure why 🤔

@jax.jit
def f(self, Y):
    # replaces A.f inside class A above: same method without the ** 2
    return jnp.sum(Y) * jnp.sum(self.X)
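To narrow the ** 2 observation down without treeo, one could compare two plain jitted functions that compute the same value, one with ** 2 and one with an explicit multiplication. This is a sketch: squared_pow and squared_mul are hypothetical names, and _cache_size() is the same private JAX helper used in the original report, so its availability and meaning may vary between JAX versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def squared_pow(Y):
    return jnp.sum(Y ** 2)  # uses the ** operator, as in the original report


@jax.jit
def squared_mul(Y):
    return jnp.sum(Y * Y)  # same value, computed without **


Y = jnp.ones(2)
for name, fn in (("Y ** 2", squared_pow), ("Y * Y", squared_mul)):
    sizes = []
    for i in range(5):
        sizes.append(fn._cache_size())  # cache size before each call
        fn(Y)
    print(name, sizes)
```

If only the ** 2 variant reaches a cache size of 2, that points at how the power operation is handled rather than at the method being part of a pytree class.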
pipme commented 2 years ago

OK, it seems this is not related to treeo: a plain jitted function, with no treeo class involved, shows the same behavior.

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)

for i in range(5):
    print(f._cache_size())
    f(Y)

This alone also prints 0 1 2 2 2, so I will probably ask in the JAX repo if needed.
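A check that does not rely on the private _cache_size() helper is to count how many times JAX traces the Python function: Python side effects in a jitted function run only while JAX traces it, not on calls served from the cache. This is a minimal sketch; trace_count is a hypothetical name introduced for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while f is being traced


@jax.jit
def f(Y):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(Y ** 2)


Y = jnp.ones(2)
for i in range(5):
    print(i, f._cache_size(), trace_count)  # cache size and trace count before the call
    f(Y)
```

If trace_count stays at 1 while the cache size still reaches 2, that would suggest the extra cache entry is not caused by JAX retracing the body of f.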