thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

PER weight does not support customized replay buffer #215

Closed · GIS-PuppetMaster closed this issue 4 years ago

GIS-PuppetMaster commented 4 years ago

I'm running TD3 with the newest version of tianshou. When td3.learn is called, an exception is raised because td1 is a tensor while weight is a numpy.ndarray: PyTorch automatically tries to convert the td1 tensor into a numpy.ndarray, but since I'm running on the GPU, the conversion fails. Here is the code that raises the exception:

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    weight = batch.pop('weight', 1.)
    # critic 1
    current_q1 = self.critic1(batch.obs, batch.act).flatten()
    target_q = batch.returns.flatten()
    td1 = current_q1 - target_q
    critic1_loss = (td1.pow(2) * weight).mean()  # <-- exception raised here
Epoch #1:   0%|          | 0/1 [00:21<?, ?it/s]
Traceback (most recent call last):
  File "D:/PycharmProjects/Stable-BaselineTrading/Tianshou/TD3.py", line 85, in <module>
    save_fn=lambda p: torch.save(p.state_dict(), save_dir))
  File "D:\PycharmProjects\Stable-BaselineTrading\Tianshou\Trainer\offpolicy.py", line 110, in offpolicy_trainer
    losses = policy.update(batch_size, train_collector.buffer)
  File "C:\Users\zkx74\Anaconda3\envs\RL\lib\site-packages\tianshou\policy\base.py", line 147, in update
    result = self.learn(batch, *args, **kwargs)
  File "C:\Users\zkx74\Anaconda3\envs\RL\lib\site-packages\tianshou\policy\modelfree\td3.py", line 123, in learn
    critic1_loss = (td1.pow(2) * weight).mean()
  File "C:\Users\zkx74\Anaconda3\envs\RL\lib\site-packages\torch\tensor.py", line 480, in __array__
    return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

By the way, it's very odd that the same code runs normally on a CentOS GPU server without this error.

Here are my environment versions:

numpy 1.19.0
torch 1.6.0
cuda 10.1
tianshou 0.2.7

GIS-PuppetMaster commented 4 years ago

I think the best way to fix this is to detect whether the TD error is a tensor on the GPU and, if so, convert weight to a tensor on the same device.
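
A minimal sketch of that idea (align_weight is a hypothetical helper, not part of tianshou; td1 and weight are the names from the snippet above):

import numpy as np
import torch

# Hypothetical helper: if the TD error is a torch tensor and the PER
# weight is still a numpy array, move the weight onto the tensor's
# device with a matching dtype before they are multiplied.
def align_weight(weight, td_error):
    if isinstance(td_error, torch.Tensor) and isinstance(weight, np.ndarray):
        weight = torch.as_tensor(weight, dtype=td_error.dtype,
                                 device=td_error.device)
    return weight

# Inside learn():
# weight = align_weight(weight, td1)
# critic1_loss = (td1.pow(2) * weight).mean()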

Trinkle23897 commented 4 years ago

Could you please provide some more details? On my own Linux machine, there is no such exception. Is your OS platform Win10? Did you use PER in TD3?

Also, add one line before critic1_loss = (td1.pow(2) * weight).mean():

print(td1.dtype, type(weight), weight.dtype, batch.weight)

GIS-PuppetMaster commented 4 years ago

> Could you please provide some more details? On my own Linux machine, there is no such exception. Is your OS platform Win10? Did you use PER in TD3?

I'm running on Win10 with PER in TD3. Oh, I customized my own PER based on the last version; maybe it's because I didn't update it to this version. I will check my code and reply later.

Trinkle23897 commented 4 years ago

Possibly it is because you created a new class that doesn't inherit from the PrioritizedReplayBuffer class, so it fails the check in BasePolicy: https://github.com/thu-ml/tianshou/blob/16d8e9b05182bb96c7fd3e46a922cfa1b862705c/tianshou/policy/base.py#L246-L248
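
For reference, the check at those lines looks roughly like this (paraphrased, not the exact source), so any buffer class outside the PrioritizedReplayBuffer hierarchy silently keeps its weight as a float64 numpy array:

# Roughly the check in BasePolicy (paraphrased): the weight is only
# converted to a torch tensor when the buffer is a
# PrioritizedReplayBuffer (or a subclass of it).
if isinstance(buffer, PrioritizedReplayBuffer):
    batch.weight = to_torch_as(batch.weight, target_q_torch)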

GIS-PuppetMaster commented 4 years ago

> Possibly it is because you created a new class that doesn't inherit from the PrioritizedReplayBuffer class, so it fails the check in BasePolicy.

I customized a StockReplayBuffer class that inherits from ReplayBuffer, and implemented a StockPrioritizedReplayBuffer class that inherits from StockReplayBuffer; both of them override the add method.
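
In other words, the hierarchy looks roughly like this (class names from the comment above, method bodies elided), so neither class passes an isinstance check against PrioritizedReplayBuffer:

from tianshou.data import ReplayBuffer

class StockReplayBuffer(ReplayBuffer):
    def add(self, **kwargs):  # overridden add()
        ...

class StockPrioritizedReplayBuffer(StockReplayBuffer):
    def add(self, **kwargs):  # overridden add()
        ...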

Trinkle23897 commented 4 years ago

But if you want to use the prioritized experience replay feature, please inherit from PrioritizedReplayBuffer instead of ReplayBuffer. Alternatively, change the lines in BasePolicy listed above to:

if isinstance(buffer, (PrioritizedReplayBuffer, StockReplayBuffer)):
    batch.weight = to_torch_as(batch.weight, target_q_torch)

GIS-PuppetMaster commented 4 years ago

> print(td1.dtype, type(weight), weight.dtype, batch.weight)

The output is torch.float32 <class 'numpy.ndarray'> float64, and batch.weight raises AttributeError: 'Batch' object has no attribute 'weight'.

Trinkle23897 commented 4 years ago

> But if you want to use the prioritized experience replay feature, please inherit from PrioritizedReplayBuffer instead of ReplayBuffer. Alternatively, change the lines in BasePolicy listed above.

Try this first and see if the issue is fixed.

GIS-PuppetMaster commented 4 years ago

> But if you want to use the prioritized experience replay feature, please inherit from PrioritizedReplayBuffer instead of ReplayBuffer. Alternatively, change the lines in BasePolicy listed above.

I used a clumsy approach: I copied the PrioritizedReplayBuffer code and edited it. I will try inheriting from it or using the method you suggested. Thanks for your help!

GIS-PuppetMaster commented 4 years ago

> Try this first and see if the issue is fixed.

I used multiple inheritance: my own PER now inherits from both my replay buffer class and tianshou's PrioritizedReplayBuffer at the same time, and everything works normally now. Thank you very much!
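
A minimal sketch of that multiple-inheritance fix (assuming the class names above):

from tianshou.data import PrioritizedReplayBuffer

# Inheriting from both classes keeps the custom add() logic while
# making isinstance(buffer, PrioritizedReplayBuffer) true in
# BasePolicy, so batch.weight gets converted to a tensor as expected.
class StockPrioritizedReplayBuffer(StockReplayBuffer, PrioritizedReplayBuffer):
    pass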

GIS-PuppetMaster commented 4 years ago

But it's really weird that my code worked on Linux before; it shouldn't have.

Trinkle23897 commented 4 years ago

Hmm... I think there is a better way to solve this issue: change https://github.com/thu-ml/tianshou/blob/16d8e9b05182bb96c7fd3e46a922cfa1b862705c/tianshou/policy/base.py#L246-L248 to

if hasattr(batch, "weight"):
    batch.weight = to_torch_as(batch.weight, target_q_torch)

and change https://github.com/thu-ml/tianshou/blob/c91def6cbcf572d3d2e645d7fec9f580bdb42e60/tianshou/policy/base.py#L141-L143 to

if hasattr(batch, "weight"):
    buffer.update_weight(indice, batch.weight)

I'll fix it later.

GIS-PuppetMaster commented 4 years ago

> Hmm... I think there is a better way to solve this issue [...] I'll fix it later.

Indeed.

Trinkle23897 commented 4 years ago

Please have a look at #217