coax-dev / coax

Modular framework for Reinforcement Learning in python
https://coax.readthedocs.io
MIT License

Incorporating jax.jit into a custom policy #20

Closed: UweGensheimer closed this issue 2 years ago

UweGensheimer commented 2 years ago

I'm a bit new to JAX, so my question might sound very naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows:

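```python
import haiku as hk
import jax.numpy as jnp

class CustomPolicy(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)

    def __call__(self, x):
        # Learnable weights; the shape is elided here ("...").
        w = hk.get_parameter("w", shape=..., dtype=x.dtype, init=jnp.zeros)
        # some computation that produces `out` from x and w
        return out
```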

Following the documentation, we then define

```python
def custom_policy(S, is_training=True):
    net = CustomPolicy()  # instantiate the Haiku module defined above
    return {'logits': net(S)}
```

and finally we construct the policy:

```python
pi = coax.Policy(custom_policy, env)
```

I was wondering whether there is any way to incorporate @jax.jit into this structure to speed it up further. Thanks.

KristianHolsheimer commented 2 years ago

Hi there, thanks for your interest in coax!

If you want to, yes: you can put any number of jax.jit calls in your custom policy. I wouldn't expect it to make much of a difference, though, because the underlying Haiku-transformed function is jitted anyway (here).
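Here is a minimal pure-JAX sketch of why an extra inner jit shouldn't add a speedup. It uses no coax or Haiku, and `logits_fn`, the weight shape, and the trace counter are hypothetical. A `jax.jit` called inside an already-jitted function is traced into the outer compiled computation, so both versions compute the same result. The Python body also runs only while JAX traces it, not on every call.

```python
import jax
import jax.numpy as jnp

# Counts how often the Python body below runs; it only runs while JAX traces.
trace_count = {"n": 0}


def logits_fn(w, x):
    # Hypothetical stand-in for "some computation": a linear map to logits.
    trace_count["n"] += 1
    return x @ w


# The same computation with its own (inner) jax.jit.
inner_jitted = jax.jit(logits_fn)


@jax.jit
def policy_plain(w, x):
    return jax.nn.log_softmax(logits_fn(w, x))


@jax.jit
def policy_nested(w, x):
    # The inner jit is traced into this outer computation and compiled
    # together with it, so don't expect an extra speedup from it.
    return jax.nn.log_softmax(inner_jitted(w, x))


w = jnp.ones((4, 2))
x = jnp.arange(4.0)

policy_plain(w, x)
policy_plain(w, x)  # same shapes/dtypes: reuses the cached compilation
print(trace_count["n"])  # 1

print(bool(jnp.allclose(policy_plain(w, x), policy_nested(w, x))))  # True
```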

Does that make sense?

UweGensheimer commented 2 years ago

Oh I see, perfect. Indeed, coax is such a lifesaver. Thanks again for providing this framework; I hope I can contribute to it in the future too.

KristianHolsheimer commented 2 years ago

That's wonderful to hear, I'm happy that you find it useful!