UweGensheimer closed this issue 2 years ago.
Hi there, thanks for your interest in coax!
If you want to, you can put any number of jax.jit calls in your custom policy, yes. I wouldn't expect it to make much of a difference, though, because the underlying Haiku-transformed function is jitted anyway (here).
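The following is a minimal sketch, using plain JAX, of why an extra jax.jit inside an already-jitted function makes little difference. The functions `inner`, `inner_plain`, `outer_with_inner_jit` and `outer_plain` are hypothetical stand-ins: the inner ones for the body of a custom policy, the outer ones for the function that coax jits. If a jitted function is called while an outer jitted function is being traced, JAX records it as a nested call inside the outer computation. The whole thing is then compiled as one XLA program, so the result matches the version without the inner jit.

```python
import jax
import jax.numpy as jnp

# Hypothetical inner function with its own jit, standing in for a policy body.
@jax.jit
def inner(x):
    return jnp.tanh(x) * 2.0

# The same computation without the inner jit.
def inner_plain(x):
    return jnp.tanh(x) * 2.0

# Outer functions, jitted as a whole (analogous to the jitted
# Haiku-transformed function). While tracing the outer function,
# the call to `inner` becomes a nested call in the same program.
@jax.jit
def outer_with_inner_jit(x):
    return inner(x) + 1.0

@jax.jit
def outer_plain(x):
    return inner_plain(x) + 1.0

x = jnp.arange(4.0)
y1 = outer_with_inner_jit(x)
y2 = outer_plain(x)
print(bool(jnp.allclose(y1, y2)))  # True
```

The inner decorator is harmless, so it can stay if the function is also called on its own elsewhere. It simply does not add speed once the outer function is jitted.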
Does that make sense?
Oh, I see... perfect. Thanks! COAX is truly a lifesaver. Thanks again for providing this framework; I hope I can contribute to it in the future too.
That's wonderful to hear, I'm happy that you find it useful!
I'm a bit new to JAX, so my question might sound naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows:
Following the documentation, we then define
and finally, the policy is constructed as follows:
pi = coax.Policy(custom_policy, env)
I was wondering whether there is any way to incorporate @jax.jit into this structure to further improve performance. Thanks.
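To make the question concrete outside of coax, here is a minimal pure-JAX sketch that assumes a hypothetical linear softmax policy over discrete actions. The names `policy_logits`, `reinforce_loss`, `update` and `LEARNING_RATE` are illustrative only and are not part of the coax API. The sketch shows where `jax.jit` usually helps most: wrapping the whole gradient-and-update step, rather than individual pieces of the policy function.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical step size

def policy_logits(params, obs):
    # Hypothetical linear policy: obs is (batch, obs_dim),
    # params["w"] is (obs_dim, n_actions), params["b"] is (n_actions,).
    return obs @ params["w"] + params["b"]

def reinforce_loss(params, obs, actions, returns):
    # REINFORCE maximizes E[G_t * log pi(a_t | s_t)], so minimize the negative.
    log_probs = jax.nn.log_softmax(policy_logits(params, obs))
    # Select log pi(a_t | s_t) for the action actually taken at each step.
    chosen = jnp.take_along_axis(log_probs, actions[:, None], axis=1)[:, 0]
    return -jnp.mean(returns * chosen)

@jax.jit
def update(params, obs, actions, returns):
    # Forward pass, gradient and SGD step are compiled together by XLA.
    grads = jax.grad(reinforce_loss)(params, obs, actions, returns)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Usage with random placeholder data (not real environment rollouts).
obs_dim, n_actions, batch = 4, 2, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w": jnp.zeros((obs_dim, n_actions)), "b": jnp.zeros(n_actions)}
obs = jax.random.normal(k1, (batch, obs_dim))
actions = jax.random.randint(k2, (batch,), 0, n_actions)
returns = jax.random.normal(k3, (batch,))
new_params = update(params, obs, actions, returns)
```

Jitting the full update step lets XLA fuse the forward pass, the gradient and the parameter update into one compiled program. Adding further jax.jit decorators to helpers such as `policy_logits`, which are only called from inside `update`, would not change that.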