google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Can we use a neural network defined in TensorFlow as recurrent_fn in MuZero? #6

Closed: hejujie closed this issue 2 years ago

hejujie commented 2 years ago

We are using MuZero, with the network and loss defined in TensorFlow and the MCTS written in plain Python. The MCTS is very slow, so we would like to use mctx as a black box to replace it while keeping the neural network in TensorFlow. We are not familiar with JAX. Can we simply pass a TensorFlow module as the recurrent_fn argument of muzero_policy?

fidlej commented 2 years ago

You may try tf2jax or jax2tf.call_tf.

hejujie commented 2 years ago

> You may try tf2jax or jax2tf.call_tf.

Thanks!