jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Function in class not recompiled after updating wrapper #9023

Open fsahli opened 2 years ago

fsahli commented 2 years ago

I am having trouble updating a jitted function within a class. The function is not updated, and my guess is that it is not being recompiled by jit after the update. It is worth noting that this only happens when jit is involved; the pure Python version of this code works as expected. Here is a minimal example that reproduces this behavior.

from jax import jit
from functools import partial

class test:
    def __init__(self) -> None:
        pass
    @partial(jit, static_argnums=(0,))
    def wrapper(self,x):
        return self.fn(x)
    def assign_fn(self,fn):
        self.fn = fn

t = test()

fn = lambda x: x + 1
t.assign_fn(fn)
print(t.wrapper(1)) # prints 2, as expected

fn = lambda x: x + 2
t.assign_fn(fn)
print(t.wrapper(1)) # still prints 2, although it should print 3
print(t.fn(1)) # prints 3
print(t.wrapper(1)) # still prints 2

Is there a way to make this work? My understanding is that any change in the static arguments would force recompilation, but I am not sure how this works when the static argument is a class instance.

Thanks

jakevdp commented 2 years ago

The issue is that when you declare an argument to a function static, the function is only re-traced if that argument's hash changes, and your class's hash does not change when fn is updated:

t = test()

t.assign_fn(lambda x: x + 1)
hash1 = hash(t)

t.assign_fn(lambda x: x + 2)
hash2 = hash(t)

hash1 == hash2
# True
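
To see why an unchanged hash leads to a stale result, here is a rough plain-Python analogy. It is only a toy model of a cache keyed on a static argument, not how jax.jit is actually implemented; `cached_call`, `_cache`, and `Holder` are made-up names for illustration.

```python
# Toy stand-in for a compilation cache keyed on a static argument.
# This is an analogy only; it is not how jax.jit is implemented.
_cache = {}


def cached_call(obj, x):
    # Look up by hash/equality of `obj`, as a dict does.
    if obj not in _cache:
        _cache[obj] = obj.fn  # "compile" once: capture the current fn
    return _cache[obj](x)


class Holder:  # hypothetical class mirroring `test` above
    def __init__(self):
        self.fn = None


h = Holder()
h.fn = lambda x: x + 1
first = cached_call(h, 1)

h.fn = lambda x: x + 2
second = cached_call(h, 1)  # cache hit on the same object: stale fn is used

print(first, second)  # 2 2
```

Because `Holder` inherits the default identity-based `__hash__` and `__eq__`, the second lookup finds the entry created by the first call and never looks at the new `fn`.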

There are two ways you could fix this:

  1. You could make your class behave as expected as a static argument by overriding __hash__ and __eq__ for your test class so that its hash and equality characteristics correctly depend on the identity of fn.
  2. You could make your class behave correctly as a dynamic argument by registering your class as a pytree. Then it could be used within jit without having to mark it as a static input.
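
A minimal sketch of option 1, assuming the class only needs to track `fn`. The class is renamed `Test` and given a constructor argument here purely for illustration; the `__hash__` and `__eq__` bodies are one possible choice, not the only one.

```python
from functools import partial

import jax


class Test:
    def __init__(self, fn=None):
        self.fn = fn

    def assign_fn(self, fn):
        self.fn = fn

    # Hash and equality follow the stored function, so swapping `fn`
    # makes the instance look like a different static argument to jit.
    def __hash__(self):
        return hash(self.fn)

    def __eq__(self, other):
        return isinstance(other, Test) and self.fn is other.fn

    @partial(jax.jit, static_argnums=(0,))
    def wrapper(self, x):
        return self.fn(x)


t = Test()
t.assign_fn(lambda x: x + 1)
first = t.wrapper(1)

t.assign_fn(lambda x: x + 2)
second = t.wrapper(1)  # hash(t) changed, so wrapper is traced again

print(first, second)  # 2 3
```

One caveat: mutating an object after it has been used as a hash key is fragile in general, so this approach is probably best kept to cases where `fn` changes rarely.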

Option 2 is probably the cleaner approach.
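
A minimal sketch of option 2, assuming the class holds no array data and `fn` can be treated as static metadata of the pytree. The class is renamed `Test` for illustration, and storing `fn` directly as `aux_data` is an assumption of this sketch.

```python
import jax
from jax import tree_util


@tree_util.register_pytree_node_class
class Test:
    def __init__(self, fn=None):
        self.fn = fn

    def assign_fn(self, fn):
        self.fn = fn

    @jax.jit  # no static_argnums: `self` is passed as a pytree
    def wrapper(self, x):
        return self.fn(x)

    def tree_flatten(self):
        children = ()       # no array leaves in this toy class
        aux_data = self.fn  # static metadata; must be hashable and comparable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


t = Test()
t.assign_fn(lambda x: x + 1)
first = t.wrapper(1)

t.assign_fn(lambda x: x + 2)
second = t.wrapper(1)  # different aux_data, so wrapper is traced again

print(first, second)  # 2 3
```

Here the function is part of the pytree structure that jit uses as part of its cache key, so assigning a new `fn` leads to a new trace. Since lambdas compare by identity, each distinct function object gets its own trace even when the code is the same. If the class also held arrays, they would go in `children` and be traced as ordinary dynamic inputs.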

fsahli commented 2 years ago

Thanks for the prompt response, I will take a look.