EladSharony opened this issue 1 month ago
That, and we should also be able to execute this directly on device. I'll push some changes.
Just FYI you could do this instead:
```python
# From documentation
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import (
    ReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
    TensorDictReplayBuffer,
)

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(10, device=torch.device("cuda")),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample = rb.sample(10)

# Check devices
print(f"sample device: {sample.device}\n"
      f"sample['_weight'] device: {sample['_weight'].device}")
```
This will put your weights on cuda.
There are two issues in patching the PRB to account for the device of the storage:
1. The issue you're having is caused by the fact that, for the `ReplayBuffer` class, the device of the storage is unknown (it could be `None`). Also, the sampler is unaware of what the storage is; you could have multiple storages, for instance. So in practice, if we want to cast the content of the info dict to the storage device, we would need to pass the storage device to the sampler and do the transfer there. Another option would be for the buffer (and not the sampler) to do the casting if and only if the info dict is requested, which would avoid useless H2D transfers when the info dict isn't asked for (see the sketch after the example below), but then we would still face issue (2) below.
2. If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the `sample` method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict:
```python
# From documentation
import functools

import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler

device = "cuda"

# Patch PrioritizedSampler.sample so that whatever it returns is moved to `device`
sample = PrioritizedSampler.sample

@functools.wraps(sample)
def new_sample(self, *args, **kwargs):
    out = sample(self, *args, **kwargs)
    out = torch.utils._pytree.tree_map(lambda x: x.to(device), out)
    return out

PrioritizedSampler.sample = new_sample

rb = ReplayBuffer(
    storage=LazyTensorStorage(10, device=torch.device(device)),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
# The transform is applied to the sampled data only, not to the info dict
rb.append_transform(lambda x: x.to("cpu"))

priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
rb.add(data_0)
rb.add(data_1)
rb.update_priority(torch.tensor([0, 1]), priority=priority)
sample, info = rb.sample(10, return_info=True)

print(f"sample device: {sample.device}\n"
      f"info['_weight'] device: {info['_weight'].device}")
```
So to recap:
- The PRB is currently only hosted on CPU; it's the only part of the lib that relies on C++ code. The fact that the computation is done on CPU is why you're getting the info dict on CPU.
- Mapping to the storage device could be done.
- We could do the sumtree and mintree on CUDA; that shouldn't be too hard. In the meantime, we can send the info dict content to the storage device (see #2527), but that will only be an incomplete patch if you're not using `TensorDictReplayBuffer`.
> Also, the sampler is unaware of what the storage is. You could have multiple storages for instance.
Maybe I'm missing something, but `def sample(self, storage: Storage, batch_size: int)` accepts the storage as an argument, so we can query `storage.device`, which will also cover the multiple-storages case.
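Building on that observation, here is a hedged sketch of what a storage-aware sampler could look like. The `(index, info)` return convention and the storage's `device` attribute are assumptions based on this thread, not a confirmed API.

```python
import torch
from torchrl.data.replay_buffers import PrioritizedSampler

class StorageDevicePrioritizedSampler(PrioritizedSampler):
    """Sketch: move the info dict to the device of the storage passed to sample()."""

    def sample(self, storage, batch_size):
        index, info = super().sample(storage, batch_size)
        device = getattr(storage, "device", None)  # storages without a fixed device yield None
        if device is not None and info:
            info = {k: torch.as_tensor(v).to(device) for k, v in info.items()}
        return index, info
```

Since the buffer hands the storage to `sample` at call time, this would follow whichever storage is actually being sampled from.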
> If we map the info from the PRB to the device of the storage, it may still be incomplete. In the following example, I patch the sample method but also append a device map as a transform in the buffer. As this example shows, our transform will rightfully ignore the info dict.
That's a valid point. I wanted to suggest adding `info` to the `data`, but preallocating the memory might not be that trivial. On the other hand, I can't think of any reason (besides mapping to a device) why one would need to transform the `info` dict.
**Describe the bug**

The device of `info['_weight']` doesn't match the storage device.

**To Reproduce**
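The original reproduction snippet isn't shown above; a minimal version along these lines (reconstructed from the discussion, so names and values are illustrative) exhibits the mismatch:

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage, PrioritizedSampler

rb = ReplayBuffer(
    storage=LazyTensorStorage(10, device=torch.device("cuda")),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
)
rb.add(TensorDict({"obs": [0], "action": [0], "reward": 0}, []))

sample, info = rb.sample(4, return_info=True)
print(sample.device)           # storage device, e.g. cuda:0
print(info["_weight"].device)  # cpu, hence the mismatch
```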
**Expected behavior**

Both should be on the same device defined in `storage(..., device)`, as these weights are later used to compute the loss.

**System info**
**Reason and Possible fixes**

Specify the `device` argument in samplers.py (L508):

**Checklist**
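As an illustration of the "Reason and Possible fixes" suggestion above, here is a hedged sketch of a sampler that accepts a `device` argument. This is a hypothetical subclass, not the actual change to torchrl's samplers.py, and the `(index, info)` return convention is assumed from this thread.

```python
import torch
from torchrl.data.replay_buffers import PrioritizedSampler

class PrioritizedSamplerWithDevice(PrioritizedSampler):
    """Sketch: return the info dict (e.g. '_weight') on a user-chosen device."""

    def __init__(self, max_capacity, alpha, beta, device=None, **kwargs):
        super().__init__(max_capacity, alpha, beta, **kwargs)
        self.device = device

    def sample(self, storage, batch_size):
        index, info = super().sample(storage, batch_size)
        if self.device is not None and info:
            info = {k: torch.as_tensor(v).to(self.device) for k, v in info.items()}
        return index, info
```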