pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] info['_weight'] device for Importance Sampling in PER #2518

Open EladSharony opened 1 month ago

EladSharony commented 1 month ago

Describe the bug

The device of info['_weight'] doesn't match the storage device.

To Reproduce

# From documentation
import torch
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')), sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

# Check devices
print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
sample device: cuda:0
info['_weight'] device: cpu

Expected behavior

Both should be on the same device defined in storage(..., device) as these weights are later used to compute the loss.
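
For context on why the device matters: in prioritized experience replay, the importance-sampling weights multiply the per-sample loss, so in a tensor implementation they have to sit on the same device as the TD errors. The sketch below is plain Python rather than TorchRL code. It uses the standard PER weighting formula, the function names (importance_weights, weighted_squared_loss) are hypothetical, and it assumes strictly positive priorities. TorchRL's exact normalization may differ.

```python
def importance_weights(priorities, alpha=1.0, beta=1.0):
    """Standard PER weights: w_i = (N * P(i)) ** -beta, scaled so that max(w) == 1.

    Assumes every priority is strictly positive (hypothetical helper, not TorchRL API).
    """
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    n = len(priorities)
    # P(i) = scaled[i] / total is the probability of sampling item i.
    raw = [(n * s / total) ** (-beta) for s in scaled]
    max_w = max(raw)
    return [w / max_w for w in raw]


def weighted_squared_loss(td_errors, weights):
    """Mean of w_i * td_i ** 2: the product is where a device mismatch would fail."""
    return sum(w * e * e for w, e in zip(weights, td_errors)) / len(td_errors)


weights = importance_weights([1.0, 3.0], alpha=1.0, beta=1.0)
loss = weighted_squared_loss([2.0, 3.0], weights)
```

With tensors, the element-wise product of the weights and the TD errors is the step that raises an error when one operand is on cpu and the other on cuda.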

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.10.23 1.26.4 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] linux

Reason and Possible fixes

Specify device argument in samplers.py (L508):

weight = torch.as_tensor(self._sum_tree[index], device=storage.device)
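
Until a fix lands, a user-side workaround could look like the sketch below. The helper name move_info_to_device is hypothetical (it is not part of TorchRL), and it assumes that info is a plain dict whose tensor values expose a .to(device) method, as info['_weight'] does in the reproduction above. Values without .to (for example NumPy arrays or lists) are left untouched, and a device of None leaves everything where it is.

```python
def move_info_to_device(info, device):
    """Return a shallow copy of `info` with tensor-like values moved to `device`.

    Hypothetical helper: anything exposing `.to(device)` is moved, everything
    else is passed through unchanged. If `device` is None, nothing is moved.
    """
    if device is None:
        return dict(info)
    return {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in info.items()
    }


# Intended usage with the reproduction script (not executed here):
#   sample, info = rb.sample(10, return_info=True)
#   info = move_info_to_device(info, sample.device)
```

This only moves the weights after they were computed on CPU, so it still pays for one host-to-device copy per call.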

vmoens commented 1 month ago

That, and we should also be able to execute this directly on device. I'll push some changes.

vmoens commented 1 month ago

Just FYI you could do this instead:

# From documentation
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler, TensorDictReplayBuffer
from tensordict import TensorDict
import torch

rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10, device=torch.device('cuda')),
                  sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))

priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample = rb.sample(10)

# Check devices
print(f"sample device: {sample.device}\n"
      f"sample['_weight'] device: {sample['_weight'].device}")

which will put your weights on cuda.

There are two issues in patching the PRB to account for the device of the storage:

  1. The issue you're having is caused by the fact that, for the ReplayBuffer class, the device of the storage is not always known: it could be None. Also, the sampler is unaware of what the storage is; you could have multiple storages, for instance. So in practice, if we want to cast the content of the info dict to the storage device, we would need to pass the storage device to the sampler and do the transfer there. Another option would be for the buffer (and not the sampler) to do the casting if and only if the info dict is requested, which would avoid useless H2D (host-to-device) transfers when the info dict isn't asked for, but then we would still face issue (2) below.

  2. If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:

    
    # From documentation
    import functools

    from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler
    from tensordict import TensorDict
    import torch

    device = "cuda"

    # patch
    sample = PrioritizedSampler.sample

    @functools.wraps(sample)
    def new_sample(self, *args, **kwargs):
        out = sample(self, *args, **kwargs)
        out = torch.utils._pytree.tree_map(lambda x: x.to(device), out)
        return out

    PrioritizedSampler.sample = new_sample

    rb = ReplayBuffer(storage=LazyTensorStorage(10, device=torch.device(device)),
                      sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0))

    # map back content on cpu
    rb.append_transform(lambda x: x.to("cpu"))

    priority = torch.tensor([0, 1000])
    data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
    data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
    rb.add(data_0)
    rb.add(data_1)
    rb.update_priority(torch.tensor([0, 1]), priority=priority)
    sample, info = rb.sample(10, return_info=True)

    # Check devices
    print(f"sample device: {sample.device}\n"
          f"info['_weight'] device: {info['_weight'].device}")



So to recap:

PRB is currently only hosted on CPU. It's the only part of the lib that relies on C++ code. The fact that the computation is done on CPU is why you're getting the info dict on CPU. Mapping to the storage device could be done. We could do the sumtree and mintree on CUDA, that shouldn't be too hard. In the meantime we can send the info dict content to the storage device (see #2527), but that will only be an incomplete patch if you're not using `TensorDictReplayBuffer`.
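
For readers unfamiliar with the sumtree mentioned here, the following is a minimal pure-Python sketch of the data structure, not TorchRL's C++ implementation. The class name SumTree and its methods are hypothetical. Internal nodes store the sum of their children, so updating a priority and locating the item that covers a given amount of probability mass both take O(log n) steps. A mintree has the same layout with min in place of the sum.

```python
import random


class SumTree:
    """Array-backed binary tree whose internal nodes hold the sum of their children."""

    def __init__(self, capacity: int) -> None:
        # Round up to a power of two so that all leaves sit on the same level.
        self.size = 1
        while self.size < capacity:
            self.size *= 2
        self.capacity = capacity
        self.tree = [0.0] * (2 * self.size)  # index 0 is unused, root is index 1

    def update(self, index: int, priority: float) -> None:
        """Set the priority of item `index` and refresh the sums up to the root."""
        node = index + self.size
        self.tree[node] = float(priority)
        node //= 2
        while node >= 1:
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
            node //= 2

    def total(self) -> float:
        return self.tree[1]

    def find(self, mass: float) -> int:
        """Return the item i with prefix_sum(i) <= mass < prefix_sum(i + 1)."""
        node = 1
        while node < self.size:
            left = 2 * node
            if mass < self.tree[left]:
                node = left
            else:
                mass -= self.tree[left]
                node = left + 1
        return node - self.size

    def sample(self, rng: random.Random) -> int:
        """Draw one item with probability proportional to its priority."""
        return self.find(rng.random() * self.total())


tree = SumTree(4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.update(i, p)
```

Because every step is a sequence of scalar reads and writes on one array, this kind of structure naturally runs on the host, which is consistent with the weights coming back on CPU.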

EladSharony commented 1 month ago

> Also, the sampler is unaware of what the storage is. You could have multiple storages for instance.

Maybe I'm missing something, but def sample(self, storage: Storage, batch_size: int) accepts the storage as an argument, so we can query storage.device, which would also cover the multiple-storages case.

> If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:

That's a valid point. I wanted to suggest adding info to the data, but preallocating the memory might not be trivial. On the other hand, I can't think of any reason (besides mapping to a device) why one would need to transform the info dict.