Closed: GIS-PuppetMaster closed this issue 4 years ago.
I think the best way to fix this is to detect whether `td_error` is a tensor on GPU, and convert `weight` to a tensor stored on the same device.
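Something like this minimal sketch (`weight_like` is a hypothetical helper; tianshou's `to_torch_as` already does essentially this conversion):

```python
import numpy as np
import torch

def weight_like(weight, td_error):
    # Hypothetical helper: if td_error is a torch.Tensor (possibly on GPU),
    # move the numpy weight onto the same device with a matching dtype;
    # otherwise return the weight unchanged.
    if isinstance(td_error, torch.Tensor) and isinstance(weight, np.ndarray):
        return torch.as_tensor(weight, dtype=td_error.dtype, device=td_error.device)
    return weight
```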
Could you please provide some more details? On my own Linux machine, this exception doesn't occur. Is your OS platform Win10? Did you use PER in TD3?

And add one line before `critic1_loss = (td1.pow(2) * weight).mean()`:

```python
print(td1.dtype, type(weight), weight.dtype, batch.weight)
```
I'm running on Win10 with PER in TD3. Oh, I customized my own PER based on the previous version; maybe it's because I didn't update it to this version. I will check my code and reply later.
Possibly it is because you created a new class which doesn't inherit from the PrioritizedReplayBuffer class, which fails the check in BasePolicy:
https://github.com/thu-ml/tianshou/blob/16d8e9b05182bb96c7fd3e46a922cfa1b862705c/tianshou/policy/base.py#L246-L248
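For reference, the check at those lines is roughly the following (paraphrased from the snippet quoted later in this thread, not copied verbatim from the source):

```python
# Paraphrase of the linked BasePolicy lines: the importance weight is only
# converted to a torch tensor when the buffer is recognized as a
# PrioritizedReplayBuffer, so a custom buffer outside that hierarchy
# leaves batch.weight as a float64 numpy array.
if isinstance(buffer, PrioritizedReplayBuffer):
    batch.weight = to_torch_as(batch.weight, target_q_torch)
```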
I customized a StockReplayBuffer class that inherits from the ReplayBuffer class, and implemented a StockPrioritizedReplayBuffer class that inherits from the StockReplayBuffer class; both of them override the add method, as sketched below.
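Roughly, the hierarchy looks like this (method bodies elided; the real buffer logic is custom):

```python
from tianshou.data import ReplayBuffer

class StockReplayBuffer(ReplayBuffer):
    def add(self, **kwargs):
        # custom stock-data bookkeeping (elided)
        super().add(**kwargs)

class StockPrioritizedReplayBuffer(StockReplayBuffer):
    def add(self, **kwargs):
        # copied/adapted PER logic (elided)
        super().add(**kwargs)

# Neither class is an instance of PrioritizedReplayBuffer, so the
# isinstance check in BasePolicy is skipped and batch.weight stays a
# float64 numpy array.
```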
But if you want to use the prioritized experience replay feature, please inherit from PrioritizedReplayBuffer instead of ReplayBuffer. Or change the line in BasePolicy listed above to:

```python
if isinstance(buffer, (PrioritizedReplayBuffer, StockReplayBuffer)):
    batch.weight = to_torch_as(batch.weight, target_q_torch)
```
With

```python
print(td1.dtype, type(weight), weight.dtype, batch.weight)
```

the output is:

```
torch.float32
<class 'numpy.ndarray'>
float64
{AttributeError}'Batch' object has no attribute 'weight'
```
Try the change above first and see if the issue is fixed.
I used a stupid way: copying the PrioritizedReplayBuffer code and editing it. I will try inheriting it or using the method you said. Thanks for your help!
I used multiple inheritance: my own PER now inherits from both my replay buffer class and tianshou's PrioritizedReplayBuffer at the same time (sketched below), and now it works normally, thank you very much!
But it's really weird that my code worked on Linux before; it shouldn't have.
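Roughly, the fix looks like this (the real add logic is elided):

```python
from tianshou.data import PrioritizedReplayBuffer

class StockPrioritizedReplayBuffer(StockReplayBuffer, PrioritizedReplayBuffer):
    # Inheriting from PrioritizedReplayBuffer as well makes the
    # isinstance(buffer, PrioritizedReplayBuffer) check in BasePolicy pass,
    # so batch.weight gets converted with to_torch_as as intended.
    pass
```

Which method implementation wins is decided by Python's MRO (StockReplayBuffer comes first here), so the overridden add needs to be checked carefully.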
Hmm... I think there is a better way to solve this issue: change https://github.com/thu-ml/tianshou/blob/16d8e9b05182bb96c7fd3e46a922cfa1b862705c/tianshou/policy/base.py#L246-L248 to

```python
if hasattr(batch, "weight"):
    batch.weight = to_torch_as(batch.weight, target_q_torch)
```

and change https://github.com/thu-ml/tianshou/blob/c91def6cbcf572d3d2e645d7fec9f580bdb42e60/tianshou/policy/base.py#L141-L143 to

```python
if hasattr(batch, "weight"):
    buffer.update_weight(indice, batch.weight)
```

I'll fix it later.
Indeed
Please have a look at #217
I'm running TD3 with the newest version of tianshou, and when calling td3.learn an exception is raised because td1 is a tensor but weight is a numpy.ndarray: PyTorch automatically tries to turn the td1 tensor into a numpy.ndarray, but I'm running on GPU, so here comes the exception. Here is the code that raised the exception:
By the way, it's very odd that I can run this code normally on a CentOS GPU server without this error.
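A minimal sketch of the mismatch (hypothetical values, not my actual code); with the PyTorch versions from around this time, the CPU case silently falls back to NumPy while the CUDA case raises:

```python
import numpy as np
import torch

td1 = torch.randn(4)                   # torch.float32, as in the printout above
weight = np.ones(4, dtype=np.float64)  # float64 numpy array from the buffer

cpu_loss = (td1.pow(2) * weight).mean()         # CPU: NumPy silently takes over
gpu_loss = (td1.cuda().pow(2) * weight).mean()  # GPU: raises, CUDA tensors can't
                                                # be converted to numpy implicitly
```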
Here is my env version: