pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

Error reported in the USING PRETRAINED MODELS section #2081

Open · Zhaohhya opened this issue 2 months ago

Zhaohhya commented 2 months ago

I found a problem in the USING PRETRAINED MODELS section. The tutorial runs without errors only because all of the data stays on the CPU the whole time. If the model and R3M are placed on the GPU, it breaks. The replay buffer keeps its data on the CPU by default, and because it is created with `rb = ReplayBuffer(storage=storage, transform=r3m)`, it applies the R3M transform automatically when sampling. At that moment the sampled data is still on the CPU while R3M is on the GPU, so an error is raised.
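To make the ordering problem concrete, here is a toy, pure-Python model of it. `ToyBatch`, `ToyTransform` and `ToyReplayBuffer` are hypothetical stand-ins for a TensorDict, `R3MTransform` and torchrl's `ReplayBuffer`, not the real classes. The only behavior modeled is that a buffer built with a transform applies it inside `sample()`, before the caller gets a chance to move the batch, and that the transform rejects input on a different device. Sampling is a deterministic slice to keep the toy predictable.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a sampled TensorDict: values plus their device."""
    values: List[float]
    device: str = "cpu"

    def to(self, device: str) -> "ToyBatch":
        # Return a copy placed on the requested device.
        return ToyBatch(list(self.values), device)


class ToyTransform:
    """Hypothetical stand-in for R3MTransform: its 'weights' live on one device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def __call__(self, batch: ToyBatch) -> ToyBatch:
        # Real PyTorch also raises a RuntimeError when input and weights differ in device.
        if batch.device != self.device:
            raise RuntimeError(
                f"expected input on {self.device}, got input on {batch.device}"
            )
        return ToyBatch([2.0 * v for v in batch.values], batch.device)


class ToyReplayBuffer:
    """Hypothetical stand-in for ReplayBuffer whose storage keeps data on the CPU."""

    def __init__(self, transform: Optional[Callable[[ToyBatch], ToyBatch]] = None) -> None:
        self.storage: List[float] = []
        self.transform = transform

    def extend(self, values: List[float]) -> None:
        self.storage.extend(values)

    def sample(self, n: int) -> ToyBatch:
        # Deterministic slice instead of random sampling; the batch starts on the CPU.
        batch = ToyBatch(self.storage[:n], device="cpu")
        # The transform runs here, before the caller can call .to(device).
        return self.transform(batch) if self.transform is not None else batch


rb = ToyReplayBuffer(transform=ToyTransform(device="cuda"))
rb.extend([1.0, 2.0, 3.0])
try:
    rb.sample(2)
except RuntimeError as err:
    print("sample failed:", err)
```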

Zhaohhya commented 2 months ago

Here's how I modified it: I changed `rb = ReplayBuffer(storage=storage, transform=r3m)` to `rb = ReplayBuffer(storage=storage)`, and then added:

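```python
batch = rb.sample(32)           # sampled from CPU storage, no transform applied
batch = batch.to(device)        # move the batch to the device R3M is on
transformed_batch = r3m(batch)  # apply the R3M transform manually
```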
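If you would rather keep the transform attached to the buffer, one general pattern is to wrap the module so that it moves its input to the device of its own parameters before running. The sketch below shows this with plain PyTorch. `DeviceMatched` is a hypothetical helper name, and a small `nn.Linear` stands in for R3M. Whether such a wrapper can be passed directly as `ReplayBuffer(transform=...)` depends on which transform types your torchrl version accepts, so treat that part as an assumption to check against the docs.

```python
import torch
from torch import nn


class DeviceMatched(nn.Module):
    """Hypothetical wrapper: move the input to the wrapped module's device, then call it."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, batch):
        # Read the device from the wrapped module's first parameter, so the wrapper
        # keeps working after the module is moved with .to(device).
        device = next(self.module.parameters()).device
        # The same .to(device) call also exists on TensorDict.
        return self.module(batch.to(device))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = DeviceMatched(nn.Linear(4, 2)).to(device)

cpu_batch = torch.randn(8, 4)   # e.g. a batch sampled from CPU storage
features = encoder(cpu_batch)   # moved to `device` before the forward pass
print(features.shape, features.device)
```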