ShahRutav opened this issue 7 months ago
On it! Will ping you soon with a solution.
I edited https://github.com/facebookresearch/agenthive/pull/22. Have a look; this example should work fine now:
```python
import torch
import robohive
print(robohive.robohive_env_suite)
# from torchrl.envs import RoboHiveEnv
# from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform
from rlhive.rl_envs import make_r3m_env
# from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy

# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu")  # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))
```
Thanks, `env.rollout(3)` works after these changes. Taking the test example a step further, I collected data with a single-process `SyncDataCollector` and a multi-process `MultiaSyncDataCollector`. Here is the code snippet:
```python
import torch
import robohive

# printing the envs in the robohive env_suite
# print(robohive.robohive_env_suite)
from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy

# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu")  # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

    # a simple, single-process data collector
    collector = SyncDataCollector(
        env,
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

    # async multi-process data collector
    collector = MultiaSyncDataCollector(
        [env, env],
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)
```
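For reference, the number of iterations of each loop above follows from the collector arguments. Assuming each iteration yields `frames_per_batch` frames and the collector stops once `total_frames` have been gathered, both loops should print 1_000 / 200 = 5 batches. A small stdlib sketch of that arithmetic; `expected_batches` is a hypothetical helper, not a TorchRL function:

```python
def expected_batches(total_frames: int, frames_per_batch: int) -> int:
    """Hypothetical helper: batches a collector would yield, assuming it keeps
    collecting full batches until at least total_frames frames are gathered."""
    if frames_per_batch <= 0:
        raise ValueError("frames_per_batch must be positive")
    # Integer ceiling division, so a partial final batch still counts as one batch.
    return (total_frames + frames_per_batch - 1) // frames_per_batch


print(expected_batches(1_000, 200))  # 5
```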
`env.rollout(3)` works without any error. `SyncDataCollector` works as well. `MultiaSyncDataCollector` with two processes leads to an error:
```
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/collectors/collectors.py", line 839, in rollout
    tensordict, tensordict_ = self.env.step_and_maybe_reset(
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1942, in step_and_maybe_reset
    tensordict = self.step(tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1313, in step
    next_tensordict = self._step(tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 735, in _step
    next_tensordict = self.transform._step(tensordict, next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 970, in _step
    next_tensordict = t._step(tensordict, next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 318, in _step
    next_tensordict = self._call(next_tensordict)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3681, in _call
    raise Exception(
Exception: CatTensor failed, as it expected input keys = ['pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err', 'solved'] but got a TensorDict with keys ['done', 'pixel_r3m', 'qp_robot', 'qv_robot', 'reach_err', 'reward', 'terminated', 'truncated']
```
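Comparing the two key lists in the exception makes the mismatch explicit. A stdlib sketch, with both lists copied verbatim from the message above:

```python
# Key lists copied verbatim from the CatTensor exception above.
expected = ["pixel_r3m", "qp_robot", "qv_robot", "reach_err", "solved"]
received = ["done", "pixel_r3m", "qp_robot", "qv_robot", "reach_err",
            "reward", "terminated", "truncated"]

# Keys the concatenation expected but did not find in the TensorDict.
missing = sorted(set(expected) - set(received))
# Keys present in the TensorDict that are not part of the concatenation.
extra = sorted(set(received) - set(expected))

print("missing:", missing)  # missing: ['solved']
print("extra:", extra)      # extra: ['done', 'reward', 'terminated', 'truncated']
```

So the only expected input that is absent in the multi-process run is `solved`; the extra keys are the done/reward flags, which the concatenation does not use.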
The behavior differs between `SyncDataCollector` and `MultiaSyncDataCollector`.
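One structural difference is that `MultiaSyncDataCollector` runs each environment in its own worker process, so the environment, or whatever builds it, must be sent to that process. TorchRL's multi-process collectors also accept a list of callables that construct the environment inside each worker, for example a `functools.partial` wrapping `make_r3m_env`, instead of the `[env, env]` instances used above. Whether this changes the error is not confirmed in this thread. The stdlib sketch below only illustrates the factory pattern; `StubEnv` and `make_r3m_env_stub` are hypothetical stand-ins for the real environment and `make_r3m_env`:

```python
import functools
import pickle


class StubEnv:
    """Hypothetical stand-in for the TorchRL env returned by make_r3m_env."""

    def __init__(self, env_name: str, model_name: str, download: bool):
        self.env_name = env_name
        self.model_name = model_name
        self.download = download


def make_r3m_env_stub(env_name: str, model_name: str = "resnet18",
                      download: bool = False) -> StubEnv:
    """Hypothetical stand-in for rlhive.rl_envs.make_r3m_env."""
    return StubEnv(env_name, model_name, download)


# A zero-argument factory: every call builds a fresh environment.
make_env = functools.partial(
    make_r3m_env_stub, "FrankaReachFixed-v0", model_name="resnet18", download=True
)

# With the real libraries, this list would take the place of [env, env], e.g.
# MultiaSyncDataCollector([make_env, make_env], ...)
env_fns = [make_env, make_env]

# The factory pickles cleanly, so it can be shipped to a worker process,
# where calling it constructs that worker's own environment.
shipped = pickle.loads(pickle.dumps(make_env))
env_a, env_b = shipped(), shipped()
print(env_a is env_b)  # False: each call returns an independent instance
```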
I modified the getting-started example to run TorchRL with RoboHive. Here's the modified example:
Additionally, I changed this line to filter out the visual keys when concatenating the R3M transform output with the other keys to
This leads to an error:
I am using the following package versions:

```
robohive==0.6.0
tensordict==0.2.1
torchrl==0.2.1
```

Which versions did you use, @vmoens?
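When comparing setups, printing the installed versions programmatically avoids copying them by hand. A minimal stdlib sketch using `importlib.metadata` (available from Python 3.8, the interpreter in the traceback above); `installed_versions` is a hypothetical helper name:

```python
from importlib import metadata
from typing import Callable, Dict, Iterable


def installed_versions(
    packages: Iterable[str],
    lookup: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Map each distribution name to its installed version, or 'not installed'."""
    result = {}
    for name in packages:
        try:
            result[name] = lookup(name)
        except metadata.PackageNotFoundError:
            result[name] = "not installed"
    return result


if __name__ == "__main__":
    # Prints pip-style pins, e.g. "torchrl==0.2.1", for the packages in this report.
    for name, version in installed_versions(["robohive", "tensordict", "torchrl"]).items():
        print(f"{name}=={version}")
```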