pipme closed this issue 2 years ago.
It looks like something related to `** 2`. The following gives `0 1 1 1 1`, but I'm still not sure why 🤔
```python
@jax.jit
def f(self, Y):
    return jnp.sum(Y) * jnp.sum(self.X)
```
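One way to test the `** 2` suspicion without relying on `_cache_size()` (a private, undocumented method) is to count how often JAX traces each variant. Python side effects inside a jitted function run only while JAX traces it, not when a cached compiled version is reused. The helper `count_traces` below is hypothetical, written for this sketch; it compares `y**2` against `y * y` on the same input. If both stay at one trace while `_cache_size()` still climbs to 2, that would suggest the extra cache entry is not caused by re-tracing the function.

```python
import jax
import jax.numpy as jnp


def count_traces(body, x, n_calls=5):
    """Jit-compile `body`, call it n_calls times on x, and return how often it was traced."""
    count = 0

    def wrapped(y):
        nonlocal count
        count += 1  # Python side effect: runs only while JAX traces, not on cached calls
        return body(y)

    jitted = jax.jit(wrapped)
    for _ in range(n_calls):
        jitted(x).block_until_ready()
    return count


Y = jnp.ones(2)
pow_traces = count_traces(lambda y: jnp.sum(y**2), Y)
mul_traces = count_traces(lambda y: jnp.sum(y * y), Y)
print("y**2 traces:", pow_traces, "| y*y traces:", mul_traces)
```

Running the same script under both JAX versions would show whether `**` behaves differently from plain multiplication at the tracing level.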
OK, it seems this is not related to treeo.
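```python
import jax
import jax.numpy as jnp

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)
for i in range(5):
    print(f._cache_size())  # number of entries in f's jit cache before this call
    f(Y)
```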
This alone also gives `0 1 2 2 2`. I will probably ask in JAX's repo if needed.
With jax 0.3.15, the output of the above is `0 1 2 2 2`. No idea what's happening. It seems to work fine with 0.3.10, where the output is `0 1 1 1 1`. Thanks.
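To narrow down which call triggers the extra compilation on 0.3.15 versus 0.3.10, a possible next step is JAX's `jax_log_compiles` config flag, which logs a message (at WARNING level, so no extra logging setup is needed) each time JAX compiles a computation. This is a sketch only: the exact wording of the log messages differs between JAX versions, and other operations such as `jnp.ones` may also log their own compilations.

```python
import jax
import jax.numpy as jnp

# Log every compilation; messages go to stderr through Python's logging module.
jax.config.update("jax_log_compiles", True)


@jax.jit
def f(Y):
    return jnp.sum(Y**2)


print("jax", jax.__version__)
Y = jnp.ones(2)
results = []
for i in range(5):
    print("call", i)
    results.append(f(Y).block_until_ready())

# Turn compile logging back off once done.
jax.config.update("jax_log_compiles", False)
```

Seeing two compile messages for `f` under 0.3.15 but only one under 0.3.10 would confirm a genuine recompilation rather than only a bookkeeping difference in `_cache_size()`.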