pyrddlgym-project / pyRDDLGym

A toolkit for auto-generation of OpenAI Gym environments from RDDL description files.
https://pyrddlgym.readthedocs.io/
MIT License

Questions regarding changes to JaxRDDLCompiler - params to jax_expr #186

Closed: pecey closed this issue 1 year ago

pecey commented 1 year ago

There have been some changes in JaxRDDLCompiler that altered the lambda functions returned for CPF evaluation. In the file, the relevant logic is jax_cpfs[cpf] = self._jax(expr, info, dtype=dtype). Previously, the resulting jax expression was a lambda function that took two params: 1. a dictionary of state, action and interim variables with their current values, and 2. a PRNG key.

Now I think it expects three values, but I couldn't find documentation of what the three params should be. Could someone point me to where I should be looking, or just explain which three params the expression now expects?

Thank you.

mike-gimelfarb commented 1 year ago

Hi,

If I am understanding correctly, you are referring to the new 'params' argument in the wrapped jax expressions for RDDL calculations.

In short, this is a dictionary of per-node weight parameters that give fine control over the model relaxations used for discrete calculations, which are approximated by parameterized expressions (e.g., a sigmoid).
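A minimal sketch of what such a wrapped expression might look like, assuming the three arguments are (substitutions, params, key) in that order and that each relaxed node looks up its own weight in params under a node id. The builder make_relaxed_geq and ids such as "geq_0" are hypothetical names for illustration, not pyRDDLGym's API; print_parameterized_exprs is the place to check the real keys.

```python
import jax
import jax.numpy as jnp


def make_relaxed_geq(node_id, lhs, rhs):
    """Hypothetical builder for a relaxed comparison node `lhs >= rhs`.

    The returned function uses the assumed (subs, params, key) signature
    and reads its own sigmoid weight from `params[node_id]`.
    """
    def expr(subs, params, key):
        w = params[node_id]  # per-node weight controlling sharpness
        value = jax.nn.sigmoid(w * (subs[lhs] - subs[rhs]))
        return value, key  # no sampling here, so the key passes through
    return expr


# Two comparison nodes, each with its own weight
geq_a = make_relaxed_geq("geq_0", "x", "y")
geq_b = make_relaxed_geq("geq_1", "x", "z")

params = {"geq_0": 10.0, "geq_1": 1.0}
subs = {"x": jnp.array(1.0), "y": jnp.array(0.5), "z": jnp.array(0.5)}
key = jax.random.PRNGKey(0)

soft_a, key = geq_a(subs, params, key)  # sharp relaxation (w = 10)
soft_b, key = geq_b(subs, params, key)  # soft relaxation (w = 1)

# Gradient of node A's output with respect to every entry of params
grads = jax.grad(lambda p: geq_a(subs, p, key)[0])(params)
```

Because params is an ordinary pytree, jax.grad returns one gradient per node weight (zero for weights the expression never reads), which is one way per-node tuning could be driven.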

You can currently define and tune these per-node weights if you like (e.g. using Bayesian optimization). There is a currently unused function 'print_parameterized_exprs' in JaxExample that you can call to retrieve the keys it expects, along with their current values.

===

The story behind this is that in the previous version, there was a single global tuning parameter 'w' controlling the accuracy of the relaxations, e.g. x >= y -> sigmoid(w * (x - y)) in FuzzyLogic. In principle, however, it is possible to use per-node weight parameters, where each 'w' is tuned locally, e.g. based on local errors. We did not want to limit users who wish to adapt these parameters and gain better control over the model approximation.

This is not really used anywhere yet, nor is it currently clear how to do so without better control over intermediate calculations in jax, which is something we would like to work on in the future. (FuzzyLogic currently defines these parameters for some relaxations, so you can look there for the technical details of how they are propagated.)
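To see why the weight matters, the following sketch compares the relaxation quoted above, sigmoid(w * (x - y)), against the exact comparison for a few values of w. It calls jax.nn.sigmoid directly rather than going through pyRDDLGym's FuzzyLogic class, so treat it as an illustration of the technique only.

```python
import jax.numpy as jnp
from jax.nn import sigmoid


def exact_geq(x, y):
    # Hard (non-differentiable) comparison: 1.0 where x >= y, else 0.0
    return jnp.where(x >= y, 1.0, 0.0)


def relaxed_geq(x, y, w):
    # Relaxation from the comment: x >= y -> sigmoid(w * (x - y))
    return sigmoid(w * (x - y))


xs = jnp.linspace(-1.0, 1.0, 201)
weights = [1.0, 10.0, 100.0]

# Mean absolute gap between relaxed and exact comparison over the grid
mean_errors = [
    float(jnp.mean(jnp.abs(relaxed_geq(xs, 0.0, w) - exact_geq(xs, 0.0))))
    for w in weights
]

for w, e in zip(weights, mean_errors):
    print(f"w={w:>6}: mean |relaxed - exact| = {e:.4f}")
```

The mean error shrinks as w grows, but the sigmoid also becomes steeper, so its gradients vanish away from the decision boundary. This trade-off is typically why a single global w can be too blunt and why per-node weights may help.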

mike-gimelfarb commented 1 year ago

In the future, 'params' could also be used for propagating other information/parameters through the computation graph that one does not want to bake in, so it is really meant as a "catch-all" for propagating information through Jax.
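As a hypothetical illustration of the catch-all idea, the sketch below threads one params dict through a small expression tree under jax.jit: the comparison node reads its own weight, while a division node reads an unrelated "eps" entry. The node builders (fluent, safe_div, relaxed_geq), the key names, and the (subs, params, key) signature are all assumptions made for illustration, not pyRDDLGym's actual internals.

```python
import jax
import jax.numpy as jnp


def fluent(name):
    # Leaf node: read a fluent's current value from the substitutions
    def expr(subs, params, key):
        return subs[name], key
    return expr


def safe_div(num, den):
    # Reads a non-weight entry ("eps") from the shared params dict
    def expr(subs, params, key):
        a, key = num(subs, params, key)
        b, key = den(subs, params, key)
        return a / (b + params["eps"]), key
    return expr


def relaxed_geq(node_id, lhs, rhs):
    # Reads its own per-node weight from the same params dict
    def expr(subs, params, key):
        a, key = lhs(subs, params, key)
        b, key = rhs(subs, params, key)
        return jax.nn.sigmoid(params[node_id] * (a - b)), key
    return expr


# Relaxed (x / y) >= z, with division by y = 0 guarded by eps
graph = relaxed_geq("geq_0", safe_div(fluent("x"), fluent("y")), fluent("z"))
compiled = jax.jit(graph)

params = {"geq_0": 10.0, "eps": 1e-6}
subs = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(2.0)}
key = jax.random.PRNGKey(0)

# No node samples, so the key is passed through unchanged
value, out_key = compiled(subs, params, key)
```

Since every node receives the whole dict, adding a new kind of parameter only requires a new key, not a change to the function signatures.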

pecey commented 1 year ago

Thank you for the pointers @mike-gimelfarb. I will have a look at them.