pyrddlgym-project / pyRDDLGym

A toolkit for auto-generation of OpenAI Gym environments from RDDL description files.
https://pyrddlgym.readthedocs.io/
MIT License
68 stars, 17 forks

Jax per node parameters #176

Closed: mike-gimelfarb closed this 1 year ago

mike-gimelfarb commented 1 year ago

Ilia, I've added you as a reviewer in case you have any suggestions, since you've already worked with JAX. I'm curious whether you think the current design change will be enough to handle some of the ideas we already discussed. I'm still thinking about how to keep track of the results of intermediate calculations.

mike-gimelfarb commented 1 year ago

I've updated JaxExample to show how retrieving the model tuning parameters works.

mike-gimelfarb commented 1 year ago

I've looked at your branch briefly. I think we can avoid the NaN problem by passing a better value for the error parameter of floor. (It is not a weight here, but something with a different interpretation altogether; I set it too small based on some simple unit tests, and it was previously hard-coded.) This pull request should fix that problem.

iliathesmirnov commented 1 year ago

Hmm, why not make the error parameter a temperature-like parameter too, and make it locally adjustable along with every other temperature parameter?

Sorry Mike, you move too fast for me to keep up. I have to review a paper today (that was due Monday...), but as soon as I'm finished with that I'll take a look at the pull request!

mike-gimelfarb commented 1 year ago

The error parameter is now tunable like the temperatures, and can be set to a reasonable value such as 10^-6 to avoid the NaN problem. I am a bit reluctant to call it a "temperature", since it has neither the same interpretation nor the same valid range as the temperature of a sigmoid (e.g., it cannot be > 1), so it is simply called "error". Of course, the name can be changed freely in the JaxRDDLLogic rules without requiring any changes elsewhere.

I'd be interested in designing a simple loss function that updates these parameters using SGD on rollouts from the true model. I don't think this is hard to do now.

iliathesmirnov commented 1 year ago

OK, well, I know the current floor/ceil relaxations are already tuned and ready to go, but if needed later, it's possible to come closer to the sigmoid temperature semantics using something like the following:

https://www.desmos.com/calculator/dewlaqc4wb

$$\operatorname{ceil}(x; w) = \lceil x \rceil + \frac{1}{1 + \left(\dfrac{x - \lfloor x \rfloor}{1 - x + \lfloor x \rfloor}\right)^{-(1+w)}}$$

Though the above is still a little off from the sigmoid semantics (the exponent at a given weight is larger by 1, so the extra weight term in the gradients is also larger by 1), it's closer.

I tried taking the gradient of this relaxation in JAX, and it seems to work.
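
Since the gradient is the point of the relaxation, here is a dependency-free Python sketch of it, using the sigmoid-of-log form that comes up later in the thread and the 0.5 shift from the edit at the end of this comment. The names `soft_ceil` and `soft_ceil_grad` are hypothetical (this is not the JaxRDDLLogic code), and the derivative is written out by hand where JAX code would call `jax.grad`.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _unit_step(r: float, k: float) -> float:
    """1 / (1 + (r / (1 - r)) ** -k) on [0, 1], evaluated as sigmoid(k * log(r / (1 - r)))."""
    if r <= 0.0:
        return 0.0
    if r >= 1.0:
        return 1.0
    return _sigmoid(k * (math.log(r) - math.log1p(-r)))


def soft_ceil(x: float, w: float) -> float:
    """Hypothetical relaxed ceil: the formula above applied at x - 0.5, exponent k = 1 + w."""
    y = x - 0.5                  # the 0.5 shift moves the soft steps onto the integers
    base = math.floor(y)
    r = y - base                 # fractional part of y
    # floor(y) + 1 equals ceil(y) for non-integer y and avoids a point defect at integer y
    return base + 1.0 + _unit_step(r, 1.0 + w)


def soft_ceil_grad(x: float, w: float) -> float:
    """Analytic d soft_ceil / dx for 0 < r < 1; returns 0.0 at the breakpoints (the limit when w > 0)."""
    y = x - 0.5
    r = y - math.floor(y)
    if r <= 0.0 or r >= 1.0:
        return 0.0
    k = 1.0 + w
    s = _unit_step(r, k)
    return k * s * (1.0 - s) / (r * (1.0 - r))


# Sharpness grows with w; the slope at an integer x equals 1 + w.
for w in (0.0, 5.0, 50.0):
    print(w, soft_ceil(2.3, w), soft_ceil_grad(2.0, w))
```

At w = 0 this reduces to the linear function x + 0.5, and the slope at every integer is exactly 1 + w, so a larger w gives a sharper step, as with a sigmoid weight.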

> I'd be interested in designing a simple loss function that updates these parameters using SGD on rollouts from the true model. I don't think this is hard to do now.

Yes... I don't have any better ideas than plain mean squared error. But seeing what (if anything) goes wrong with MSE might point the way to a better loss function design? Maybe you already have another good idea.
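
As a starting point, here is a minimal sketch of plain MSE fitting of w by SGD, reusing the shifted relaxation. It is a toy under loud assumptions: a single relaxed ceil stands in for the whole model, uniformly sampled states stand in for rollouts of the true model, and `fit_w` and `mse_and_grad` are hypothetical names, not pyRDDLGym APIs.

```python
import math
import random


def _sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _relaxed_parts(x, w):
    """For y = x - 0.5: return floor(y), s = sigmoid((1 + w) * logit(r)), and logit(r)."""
    y = x - 0.5
    base = math.floor(y)
    r = y - base
    if r <= 0.0:
        return base, 0.0, 0.0
    if r >= 1.0:
        return base, 1.0, 0.0
    logit = math.log(r) - math.log1p(-r)
    return base, _sigmoid((1.0 + w) * logit), logit


def mse_and_grad(w, xs, targets):
    """MSE of the relaxed ceil against the true outputs, plus its derivative with respect to w."""
    loss = grad = 0.0
    for x, t in zip(xs, targets):
        base, s, logit = _relaxed_parts(x, w)
        err = base + 1.0 + s - t
        loss += err * err
        grad += 2.0 * err * s * (1.0 - s) * logit   # ds/dw = s * (1 - s) * logit(r)
    return loss / len(xs), grad / len(xs)


def fit_w(w0=0.5, steps=200, lr=5.0, batch=8, seed=0):
    """SGD on w; uniformly sampled states stand in for rollouts of the true model."""
    rng = random.Random(seed)
    xs = [rng.uniform(0.0, 5.0) for _ in range(64)]
    targets = [float(math.ceil(x)) for x in xs]     # the "true model" is exact ceil
    w = w0
    for _ in range(steps):
        idx = rng.sample(range(len(xs)), batch)     # random minibatch
        _, g = mse_and_grad(w, [xs[i] for i in idx], [targets[i] for i in idx])
        w -= lr * g
    return w, mse_and_grad(w0, xs, targets)[0], mse_and_grad(w, xs, targets)[0]


w_fit, loss_before, loss_after = fit_w()
print(f"w: 0.5 -> {w_fit:.2f}, MSE: {loss_before:.4f} -> {loss_after:.4f}")
```

One thing this toy makes visible: every sample's error shrinks monotonically as w grows, so MSE against the true model alone only ever pushes w upward, toward the exact step function whose gradient is zero almost everywhere. A useful loss would presumably need something pulling the other way, such as the policy objective.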

Edit: Whoops, well, the ceil relaxation above also needs to be shifted to the right by 0.5, but that's just an extra translation (use ceil(x-0.5;w) instead).

mike-gimelfarb commented 1 year ago

Interesting! I did not know this approximation of ceil. This is really promising, because one problem with the current one is that it is not very sharp around the integers, and it indeed requires a small parameter. I'll run a simple comparison against the current one to see what's going on. I also like that it has the same interpretation.

I don't currently have any better ideas than MSE either. I still like the dual gradient descent idea, where we take steps to optimize the temperature and interleave them with policy optimization (perhaps also gradually decreasing the tolerance on model error). I suspect that if the balance between the two is right, the annealing will help avoid local optima during policy optimization to some degree. But maybe not. I can write a small example next week to test these ideas.

mike-gimelfarb commented 1 year ago

Your approach is actually not far from the sigmoid interpretation. It can be written as ceil(x) + sigmoid(w * log(r/(1-r))), where r is the difference x - floor(x) (of course, taking x - 0.5 instead). I changed the implementation to yours, since implemented this way it seems to work without numerical issues even for very large w.
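
The equivalence is a one-line identity. Writing k for the exponent (k = 1 + w in the formula above), r in (0, 1) for the fractional part, and σ(z) = 1/(1 + e^(-z)):

$$
\frac{1}{1 + \left(\frac{r}{1-r}\right)^{-k}}
= \frac{1}{1 + \exp\!\left(-k \log\frac{r}{1-r}\right)}
= \sigma\!\left(k \log\frac{r}{1-r}\right),
\qquad
\frac{d}{dr}\,\sigma\!\left(k \log\frac{r}{1-r}\right) = \frac{k\,\sigma\,(1-\sigma)}{r\,(1-r)}.
$$

At k = 1 the first expression reduces to r, a linear ramp; it tends to 0 and 1 as r approaches 0 and 1 respectively, so each step completes within one unit interval; and the slope at r = 1/2 is exactly k.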

iliathesmirnov commented 1 year ago

Yeah, it is a staircase of sigmoids, analogous to ceil being a staircase of step functions :) The difference is that these sigmoids reach 0/1 over a finite interval rather than at +/- infinity. Also, I shifted the weight by adding 1 because the curve is a linear ramp at w = 1 and gets undesirable geometry for w < 1.

These can have their temperature adjusted algorithmically like every other temperature-parametrized object. And since they have a similar interpretation, I hope a single uniform algorithm can be used for all of them.