google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Pass State into recurrent_fn #8

Closed: evanatyourservice closed this issue 2 years ago

evanatyourservice commented 2 years ago

Hello, thank you for open-sourcing this great and concise implementation. I think it would be useful to be able to pass network state into the recurrent function. With something like batch norm, the state could be carried as part of params, but it would be nicer to keep it separate. In my case I'm working with a stateful RNN inside recurrent_fn, so I'd like the network state to be kept separate and passed back into recurrent_fn the same way the embedding is. I could just as easily copy the repo and modify it, but I'm wondering whether you would mind adding the feature to the repo: I think it could be useful for research, and then we could still simply install the package and use it as is.

fidlej commented 2 years ago

Thank you for the friendly request. I want to understand more about your needs. Maybe you can pass the recurrent state as a part of the embedding. The embedding does not have to be a single array. The code allows the embedding to be a tree of arrays.
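To illustrate the suggestion above, here is a minimal sketch of a recurrent function whose embedding is a dict of arrays rather than a single array. It follows the `(params, rng_key, action, embedding)` argument order that mctx uses for `recurrent_fn`. Everything else is an assumption made for illustration: the leaf names (`rnn_state`, `step`, `bn_mean`), the toy linear "network", and the plain dict returned in place of `mctx.RecurrentFnOutput` so that the snippet runs with JAX alone.

```python
import jax
import jax.numpy as jnp

BATCH, HIDDEN, NUM_ACTIONS = 4, 8, 3  # hypothetical sizes


def init_params(key):
    """Toy parameters for a hand-written recurrent cell (illustrative only)."""
    k_h, k_a, k_pi, k_v = jax.random.split(key, 4)
    return {
        "w_h": 0.1 * jax.random.normal(k_h, (HIDDEN, HIDDEN)),
        "w_a": 0.1 * jax.random.normal(k_a, (NUM_ACTIONS, HIDDEN)),
        "w_pi": 0.1 * jax.random.normal(k_pi, (HIDDEN, NUM_ACTIONS)),
        "w_v": 0.1 * jax.random.normal(k_v, (HIDDEN, 1)),
    }


def init_embedding(batch_size):
    """Embedding as a pytree; every leaf has a leading batch dimension."""
    return {
        "rnn_state": jnp.zeros((batch_size, HIDDEN)),
        "step": jnp.zeros((batch_size,), dtype=jnp.int32),
        "bn_mean": jnp.zeros((batch_size, HIDDEN)),
    }


def recurrent_fn(params, rng_key, action, embedding):
    """Advances the carried state by one step and returns it as the new embedding."""
    del rng_key  # unused in this deterministic sketch
    one_hot = jax.nn.one_hot(action, NUM_ACTIONS)
    h = jnp.tanh(embedding["rnn_state"] @ params["w_h"] + one_hot @ params["w_a"])
    new_embedding = {
        "rnn_state": h,
        "step": embedding["step"] + 1,
        # Stand-in for running statistics carried alongside the RNN state.
        "bn_mean": 0.9 * embedding["bn_mean"] + 0.1 * h,
    }
    # Stand-in for mctx.RecurrentFnOutput (same four fields).
    output = {
        "reward": jnp.zeros((h.shape[0],)),
        "discount": jnp.ones((h.shape[0],)),
        "prior_logits": h @ params["w_pi"],
        "value": (h @ params["w_v"])[:, 0],
    }
    return output, new_embedding


params = init_params(jax.random.PRNGKey(0))
embedding = init_embedding(BATCH)
action = jnp.array([0, 1, 2, 0])
output, new_embedding = recurrent_fn(params, jax.random.PRNGKey(1), action, embedding)
```

In real use the `output` dict would be replaced by `mctx.RecurrentFnOutput`, and the initial embedding would be supplied through the root passed to the search policy.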

evanatyourservice commented 2 years ago

Thanks for the quick response! Oh, OK, excellent, that will work then. I just need the Haiku RNN state, a counter, and the batch norm statistics passed in. I see now that all the embedding handling is done with tree maps, so I think this should be fine, right?
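As a rough illustration of why tree maps make a pytree embedding work, the sketch below stores and retrieves a per-node embedding using `jax.tree_util.tree_map`. It is not mctx's internal code; the helper names (`make_storage`, `write_node`, `read_node`) and the `[batch, num_nodes, ...]` storage layout are assumptions chosen for the example. The practical takeaway, under those assumptions, is that every leaf, including a scalar counter, should carry a leading batch dimension.

```python
import jax
import jax.numpy as jnp


def make_storage(embedding, num_nodes):
    """Allocates a [batch, num_nodes, ...] buffer for every leaf of the embedding."""
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros((x.shape[0], num_nodes) + x.shape[1:], x.dtype),
        embedding,
    )


def write_node(storage, node_index, embedding):
    """Writes each batch element's embedding into its own node slot."""
    batch = jnp.arange(node_index.shape[0])
    return jax.tree_util.tree_map(
        lambda s, e: s.at[batch, node_index].set(e), storage, embedding
    )


def read_node(storage, node_index):
    """Reads back the embedding stored at node_index for each batch element."""
    batch = jnp.arange(node_index.shape[0])
    return jax.tree_util.tree_map(lambda s: s[batch, node_index], storage)


# Hypothetical embedding: RNN state plus a per-example step counter.
example_embedding = {
    "rnn_state": jnp.arange(6, dtype=jnp.float32).reshape(2, 3),
    "step": jnp.array([5, 7], dtype=jnp.int32),
}
node_index = jnp.array([1, 3])  # a different node for each batch element

storage = make_storage(example_embedding, num_nodes=4)
storage = write_node(storage, node_index, example_embedding)
restored = read_node(storage, node_index)
```

Because the same function is mapped over every leaf, adding another piece of state to the embedding requires no change to the storage code.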