facebookresearch / agenthive

AgentHive provides the primitives and helpers for seamless use of RoboHive within TorchRL.

Integration with torchRL #23

Status: Open. ShahRutav opened this issue 7 months ago.

ShahRutav commented 7 months ago

I modified the getting-started example to run TorchRL with RoboHive. Here's the modified example:

import torch
import robohive
from torchrl.envs import RoboHiveEnv
from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform

from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

Additionally, I changed this line to filter out the visual keys when concatenating the R3M output with the other observation keys:

vec_keys = [k for k in base_env.observation_spec.keys() if ((k != "pixels") and ("visual" not in k))]
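The filter keeps only keys that can be flattened into the state vector. Below is a minimal stdlib sketch of the same rule, applied to a plain dict standing in for `observation_spec`. The key names and shapes are hypothetical. The extra check that skips nested (dict-valued) entries is our own assumed safeguard and is not part of the original line.

```python
from typing import Dict, List


def select_vector_keys(spec: Dict[str, object]) -> List[str]:
    """Return keys suitable for concatenation with the R3M embedding.

    Mirrors the filter above: drop "pixels" and any key containing "visual".
    Assumed extra safeguard: also drop nested groups (dict values here),
    since they cannot be concatenated like flat tensors.
    """
    return [
        k
        for k, v in spec.items()
        if k != "pixels" and "visual" not in k and not isinstance(v, dict)
    ]


if __name__ == "__main__":
    # Hypothetical spec: flat entries as shape tuples, nested groups as dicts.
    spec = {
        "pixels": (224, 224, 3),
        "visual_dict": {"cam0": (224, 224, 3)},
        "qp": (9,),
        "qv": (9,),
        "obs_dict": {"ee_pos": (3,)},
    }
    print(select_vector_keys(spec))  # ['qp', 'qv']
```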

This leads to the following error:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    print(env.rollout(3))
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1797, in rollout
    tensordict = self.reset()
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/common.py", line 1480, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 760, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 1020, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3694, in _reset
    tensordict_reset = self._call(tensordict_reset)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/torchrl/envs/transforms/transforms.py", line 3676, in _call
    out_tensor = torch.cat(values, dim=self.dim)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 2785, in __torch_function__
    return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
  File "/Users/rutavms/miniconda3/envs/agenthive/lib/python3.8/site-packages/tensordict/tensordict.py", line 5346, in _cat
    batch_size = list(list_of_tensordicts[0].batch_size)
AttributeError: 'Tensor' object has no attribute 'batch_size'
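Our interpretation of this traceback, which the thread does not confirm, is as follows. `values` mixed plain tensors with at least one TensorDict, so `torch.cat` was routed through TensorDict's `__torch_function__` to tensordict's own `_cat`. That function assumes every element is a TensorDict and reads `.batch_size` from the first one, which here was a plain tensor. The toy model below reproduces that dispatch pattern using hypothetical stdlib stand-in classes (`FakeTensor`, `FakeTensorDict`). It is not the real torch or tensordict code, and `check_concat_inputs` is only a suggested defensive check.

```python
class FakeTensor:
    """Stand-in for a plain torch.Tensor (it has no batch_size attribute)."""


class FakeTensorDict:
    """Stand-in for a TensorDict, which carries a batch_size."""

    def __init__(self, batch_size):
        self.batch_size = batch_size


def td_cat(list_of_tensordicts):
    # Like the tensordict handler in the traceback: it assumes every element
    # is a TensorDict and reads the batch size from the first one.
    batch_size = list(list_of_tensordicts[0].batch_size)
    return FakeTensorDict(batch_size)


def cat(values):
    # Like torch's __torch_function__ protocol: if any element overrides the
    # op, that override handles the whole call, plain tensors included.
    if any(isinstance(v, FakeTensorDict) for v in values):
        return td_cat(values)
    return FakeTensor()


def check_concat_inputs(values):
    # Suggested pre-check: fail with a clear message before dispatch
    # if tensors and tensordicts are mixed in one concatenation.
    kinds = {type(v).__name__ for v in values}
    if len(kinds) > 1:
        raise TypeError(f"cannot concatenate mixed types: {sorted(kinds)}")


if __name__ == "__main__":
    try:
        cat([FakeTensor(), FakeTensorDict((1,))])
    except AttributeError as err:
        print(err)  # 'FakeTensor' object has no attribute 'batch_size'
```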

I am using the following package versions: robohive==0.6.0, tensordict==0.2.1, torchrl==0.2.1. Which versions did you use? @vmoens
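When comparing setups like this, the installed versions can be gathered programmatically. The following small stdlib sketch uses `importlib.metadata` (Python 3.8+). It assumes the distribution names match the pip package names quoted above, and `package_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if absent)."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in package_versions(["robohive", "tensordict", "torchrl"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```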

vmoens commented 7 months ago

On it! Will ping you soon with a solution.

vmoens commented 7 months ago

I edited https://github.com/facebookresearch/agenthive/pull/22. Have a look; this example should work fine now:

import torch
import robohive
print(robohive.robohive_env_suite)
# from torchrl.envs import RoboHiveEnv
# from torchrl.envs import ParallelEnv, TransformedEnv, R3MTransform

from rlhive.rl_envs import make_r3m_env
# from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

ShahRutav commented 7 months ago

Thanks, env.rollout(3) works after these changes. Taking the test example a step further, I collected data with a single-process SyncDataCollector and a multi-process MultiaSyncDataCollector. Here is the code snippet:

import torch
import robohive
# printing the envs in robohive env_suite
# print(robohive.robohive_env_suite)

from rlhive.rl_envs import make_r3m_env
from torchrl.collectors.collectors import SyncDataCollector, MultiaSyncDataCollector, RandomPolicy
# make sure your ParallelEnv is inside the `if __name__ == "__main__":` condition, otherwise you'll
# be creating an infinite tree of subprocesses
if __name__ == "__main__":
    device = torch.device("cpu") # could be 'cuda:0'
    env_name = 'FrankaReachFixed-v0'
    env = make_r3m_env(env_name, model_name="resnet18", download=True)
    assert env.device == device
    # example of a rollout
    print(env.rollout(3))

    # a simple, single-process data collector
    collector = SyncDataCollector(
        env,
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

    # async multi-proc data collector
    collector = MultiaSyncDataCollector(
        [env, env],
        policy=RandomPolicy(env.action_spec),
        total_frames=1_000,
        frames_per_batch=200,
        init_random_frames=200,
    )
    for data in collector:
        print(data)

The behavior differs between SyncDataCollector and MultiaSyncDataCollector.
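When comparing the two, it may help to keep one general property in mind. This is not a diagnosis of this report. SyncDataCollector steps its env in the calling process. MultiaSyncDataCollector, by contrast, runs one worker per entry in the list and hands back each batch as soon as any worker has one ready, so batch order is not fixed. Each worker also operates on its own copy of the env it receives, not on `env` in the parent. As we read the TorchRL docs, the multi-process collectors describe that argument as a list of callables returning envs. The stdlib sketch below models this pattern, with threads standing in for worker processes. `ToyEnv`, `collect_batch` and `async_collect` are hypothetical names, not TorchRL API.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, Iterator


class ToyEnv:
    """Hypothetical stand-in for an env: each step returns a step counter."""

    def __init__(self, worker_id: int):
        self.worker_id = worker_id
        self.t = 0

    def step(self) -> int:
        self.t += 1
        return self.t


def collect_batch(make_env: Callable[[int], ToyEnv], worker_id: int, frames: int) -> Dict:
    # Each worker builds its own env from the constructor, so no step
    # state is shared with the parent or with other workers.
    env = make_env(worker_id)
    time.sleep(random.uniform(0.0, 0.01))  # simulate uneven worker speed
    return {"worker": worker_id, "frames": [env.step() for _ in range(frames)]}


def async_collect(make_env: Callable[[int], ToyEnv], num_workers: int,
                  frames_per_batch: int) -> Iterator[Dict]:
    # Yield one batch per worker in completion order, like an async collector.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [
            pool.submit(collect_batch, make_env, i, frames_per_batch)
            for i in range(num_workers)
        ]
        for fut in as_completed(futures):
            yield fut.result()


if __name__ == "__main__":
    for batch in async_collect(ToyEnv, num_workers=2, frames_per_batch=3):
        print(batch["worker"], batch["frames"])  # worker order may vary per run
```

Because every worker starts from a fresh env, each batch begins at step 1 regardless of which worker finishes first. Only the order in which batches arrive changes between runs.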