cbfinn / gps

Guided Policy Search
http://rll.berkeley.edu/gps/

Mjcpy2 updates #12

Closed wmontgomery4 closed 8 years ago

wmontgomery4 commented 8 years ago

I'm implementing a deep Q-learner that 1) needs to run for many steps/episodes and 2) needs to call world.get_data()['site_xpos'] every step to get the cost. Copying all of the data every step was slowing things down, so I modified mjcpy2 to return (state, site_xpos) instead of (state, oout), since no one ever seems to use 'oout'.

It looks like the GPS code only stores world.get_data() to get self._data()['site_xpos'] later, so this might speed things up here too, although I'd imagine trajopt/policyopt is a bigger bottleneck. Feel free to reject this pull request; I just wanted to throw this out there. Maybe a better idea is just to add an mjcpy helper method which returns the current (qacc, qvel, site_xpos).

cbfinn commented 8 years ago

This is great - thanks! How much does this change speed up world.get_data()?

@zhangmarvin, please review.

FYI, @svlevine, this looks like a fix to what you mentioned yesterday.

zhangmarvin commented 8 years ago

It looks good to me!

wmontgomery4 commented 8 years ago

I'd say my code is running ~5x faster now, but the change was made to world.step() and not world.get_data(). I basically made the change so that I could replace:

x1, oout = world.step(x0, u) # where 'oout' is never used according to old mjcpy code
site_xpos = world.get_data()['site_xpos']

with

x1, site_xpos = world.step(x0, u)

This is kind of a hack to get world.get_data() out of my main loop, but it might be useful for you too if you only need world.get_data()['site_xpos']. A better option would probably be to allow something like world.get_data('site_xpos') for fetching a single field more efficiently, but I'm not that savvy with Boost.Python.
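To make the trade-off concrete, here is a minimal pure-Python sketch of the three access patterns discussed above. MockWorld, its fields_copied counter, the 'xmat' field, and the optional field argument to get_data() are all hypothetical stand-ins for illustration; none of this is the real mjcpy API, which is implemented in C++.

```python
import copy


class MockWorld:
    """Hypothetical stand-in for the mjcpy world object (not the real API)."""

    def __init__(self):
        # Toy simulator data: one small field the cost needs,
        # and one large field it never looks at.
        self._data = {
            "site_xpos": [[0.0, 0.0, 0.0]],
            "xmat": [0.0] * 900,
        }
        self.fields_copied = 0  # number of fields copied by get_data()

    def step(self, x, u):
        # Toy dynamics: add the control to the state and move the site along x.
        x1 = [xi + ui for xi, ui in zip(x, u)]
        self._data["site_xpos"][0][0] = x1[0]
        # Mimics the patched behaviour: return (state, site_xpos).
        return x1, copy.deepcopy(self._data["site_xpos"])

    def get_data(self, field=None):
        # field=None copies every field; a field name copies only that field.
        if field is None:
            self.fields_copied += len(self._data)
            return copy.deepcopy(self._data)
        self.fields_copied += 1
        return copy.deepcopy(self._data[field])


def rollout_full_copy(world, x0, u, n_steps):
    # Original pattern: copy all fields every step, then pick out site_xpos.
    x, site_xpos = x0, None
    for _ in range(n_steps):
        x, _ = world.step(x, u)
        site_xpos = world.get_data()["site_xpos"]
    return x, site_xpos


def rollout_single_field(world, x0, u, n_steps):
    # Suggested alternative: ask get_data() for one field only.
    x, site_xpos = x0, None
    for _ in range(n_steps):
        x, _ = world.step(x, u)
        site_xpos = world.get_data("site_xpos")
    return x, site_xpos


def rollout_from_step(world, x0, u, n_steps):
    # Pattern from this pull request: take site_xpos straight from step().
    x, site_xpos = x0, None
    for _ in range(n_steps):
        x, site_xpos = world.step(x, u)
    return x, site_xpos
```

All three rollouts return the same final state and site position; they differ only in how many fields get_data() copies per step. In the mock that is every field, one field, and none, respectively.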

For context, I'm implementing a deep Q-learning framework based on DDPG (http://arxiv.org/pdf/1509.02971v5.pdf). Right now I'm just building it in parallel with GPS, but eventually it could be integrated into the GPS suite (especially since one plan is to try something like 'guided Q-learning').

cbfinn commented 8 years ago

Ok, cool. I'll merge this now then.

In the near future, I'll remove the call to world.get_data() in the agent_mjc sample() function. [Unless someone else wants to.]