nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Training on a custom gym env #17

Closed: ChrisAGBlake closed this issue 3 months ago

ChrisAGBlake commented 4 months ago

Any pointers on how to use a custom Gym-based environment? Assume it's a continuous-control task with a simple vector as the state.

nicklashansen commented 4 months ago

Hi @ChrisAGBlake, thanks for reaching out!

You'll need to add a make_env(cfg) function for your custom environment(s), along with any relevant wrappers. You can use the metaworld example here as a relatively minimal reference implementation. After that, add your new env constructor to the list here: https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/envs/__init__.py#L65. The code will then iterate through each constructor until it finds one that matches your task argument.
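A minimal sketch of the pattern described above follows. Every name in it is hypothetical: MyCustomEnv, make_custom_env, the "custom-" task prefix, and the standalone make_env dispatcher are illustrations, not TD-MPC2's actual code. The sketch assumes the classic gym-style interface (reset() returns obs; step(action) returns obs, reward, done, info). Each constructor raises ValueError for tasks it does not own, which lets a dispatcher move on to the next one. To stay dependency-free, it omits gym's observation_space/action_space attributes and the wrappers the real envs use. A real custom env will likely need those, so check the metaworld reference.

```python
from types import SimpleNamespace


class MyCustomEnv:
    """Hypothetical continuous-control env with a flat vector state.

    Classic gym-style interface: reset() -> obs,
    step(action) -> (obs, reward, done, info).
    """

    def __init__(self, obs_dim=3, action_dim=2, episode_length=100):
        assert action_dim <= obs_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.episode_length = episode_length
        self._t = 0
        self._state = [0.0] * obs_dim

    def reset(self):
        self._t = 0
        self._state = [0.0] * self.obs_dim
        return list(self._state)

    def step(self, action):
        # Clip each action component to [-1, 1].
        action = [max(-1.0, min(1.0, float(a))) for a in action]
        # Each action component nudges the matching state component.
        for i, a in enumerate(action):
            self._state[i] += 0.1 * a
        self._t += 1
        reward = -sum(x * x for x in self._state)  # stay near the origin
        done = self._t >= self.episode_length  # fixed-length episodes
        return list(self._state), reward, done, {}


def make_custom_env(cfg):
    """Hypothetical make_env(cfg): reject tasks this module does not own."""
    if not cfg.task.startswith('custom-'):
        raise ValueError(f'Unknown task: {cfg.task}')
    return MyCustomEnv()


def make_env(cfg, constructors):
    """Try each constructor in turn until one accepts cfg.task."""
    for fn in constructors:
        try:
            return fn(cfg)
        except ValueError:
            continue
    raise ValueError(f'Failed to make environment "{cfg.task}"')


cfg = SimpleNamespace(task='custom-reach')
env = make_env(cfg, [make_custom_env])
obs = env.reset()
obs, reward, done, info = env.step([0.5, -2.0])
```

Raising ValueError (rather than returning None) keeps each constructor self-contained: it decides whether a task is its own, and the dispatcher needs no per-env knowledge.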

Let me know if you run into any issues!

Edit: the main branch assumes that all episodes have the same length; use the episodic-rl branch if that is not the case.
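To decide which branch applies, you could roll out your environment a few times and compare episode lengths. The sketch below is hypothetical and does not come from the repository. It assumes the classic gym-style interface (step returns obs, reward, done, info). EarlyTerminatingEnv is a made-up toy env used only for illustration.

```python
import random


def episode_lengths(env, sample_action, n_episodes=10, max_steps=10_000):
    """Roll out n_episodes and record each episode's length (capped at max_steps)."""
    lengths = []
    for _ in range(n_episodes):
        env.reset()
        for t in range(1, max_steps + 1):
            _, _, done, _ = env.step(sample_action())
            if done:
                break
        lengths.append(t)
    return lengths


def needs_episodic_branch(lengths):
    """True if episodes differ in length, i.e. the equal-length assumption fails."""
    return len(set(lengths)) > 1


class EarlyTerminatingEnv:
    """Hypothetical toy env: ends after `limit` steps or once |x| exceeds 1."""

    def __init__(self, limit=50):
        self.limit = limit

    def reset(self):
        self.t, self.x = 0, 0.0
        return [self.x]

    def step(self, action):
        self.t += 1
        self.x += action[0]
        done = self.t >= self.limit or abs(self.x) > 1.0
        return [self.x], -abs(self.x), done, {}


random.seed(0)
lengths = episode_lengths(EarlyTerminatingEnv(), lambda: [random.uniform(-0.5, 0.5)])
print(lengths, needs_episodic_branch(lengths))
```

Random actions may not trigger every termination condition, so a result of "equal lengths" is a hint rather than a proof.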

ChrisAGBlake commented 4 months ago

Thanks very much!

ChrisAGBlake commented 4 months ago

I'm currently getting the following error when trying to run with a custom env (I'm on the episodic-rl branch):

Traceback (most recent call last):
  File "/home/chris/Code/Libs/tdmpc2/tdmpc2/train.py", line 57, in train
    buffer=Buffer(cfg),
  File "/home/chris/Code/Libs/tdmpc2/tdmpc2/common/buffer.py", line 18, in __init__
    self._sampler = SliceSampler(
TypeError: Can't instantiate abstract class SliceSampler with abstract methods load_state_dict, state_dict

I can replicate this error with the following code:

# Run from the tdmpc2/tdmpc2 directory so that the `common` package is importable
from common.samplers import SliceSampler
sampler = SliceSampler(num_slices=256, end_key=None, traj_key='episode', truncated_key=None)

nicklashansen commented 4 months ago

This looks like an error originating from the torchrl package. Do you get the same error when running any of the existing tasks? If so, upgrading or downgrading torchrl might fix the problem.
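The TypeError itself comes from Python's abc module. A class cannot be instantiated while any method declared with @abstractmethod in a base class remains unimplemented. The sketch below reproduces the message with a stand-in base class; it is not torchrl's real Sampler. A plausible explanation, not confirmed in this thread, is that a newer torchrl nightly added state_dict/load_state_dict as abstract methods to the sampler base class, which would explain why pinning an older version helps. Implementing the two methods in the subclass is an alternative that was not discussed here.

```python
from abc import ABC, abstractmethod


class Sampler(ABC):
    """Stand-in base class (NOT torchrl's actual Sampler) that declares
    the abstract methods named in the traceback."""

    @abstractmethod
    def sample(self, storage, batch_size):
        ...

    @abstractmethod
    def state_dict(self):
        ...

    @abstractmethod
    def load_state_dict(self, state_dict):
        ...


class OldSliceSampler(Sampler):
    """Written against an older base class that only required sample()."""

    def sample(self, storage, batch_size):
        return list(range(batch_size))


class PatchedSliceSampler(OldSliceSampler):
    """Implements the newly required methods; trivial because it keeps no state."""

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass


try:
    OldSliceSampler()
except TypeError as err:
    print(err)  # the message lists the missing abstract methods

sampler = PatchedSliceSampler()
```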

A potentially easier solution is to use my prebuilt Docker image for debugging; it should run the code without issues. Instructions for running the Docker image can be found here.

ChrisAGBlake commented 4 months ago

Thanks, I downgraded the torchrl-nightly package to 2023.10.25 and it worked :)
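When pinning versions like this, it helps to confirm which torchrl distribution is actually installed. The stable `torchrl` and the `torchrl-nightly` distributions can both provide the same `torchrl` import package. The helper below is a hypothetical convenience, not part of TD-MPC2. It uses only the standard library's importlib.metadata and reports None for distributions that are not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torchrl", "torchrl-nightly")):
    """Return {distribution name: installed version, or None if absent}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions())
```

If both distributions show a version, uninstalling one of them avoids ambiguity about which code is actually imported.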

nicklashansen commented 3 months ago

Closing this issue for now, but feel free to reopen if you have any follow-up questions!