Open ankeshanand opened 4 years ago
In PyTorch >= 1.4, grad_norm is a torch tensor (changed in https://github.com/pytorch/pytorch/pull/32020) and not a float, so the logger throws an exception here (values is now a list of PyTorch tensors): https://github.com/astooke/rlpyt/blob/a54cb5b1ee7b68d757aa0baa6a2786548419e366/rlpyt/utils/logging/logger.py#L457
To maintain backwards compatibility, an easy fix is to replace the value appended in calls like https://github.com/astooke/rlpyt/blob/35af87255465b3644747294f7fd1ff6045dab910/rlpyt/algos/dqn/dqn.py#L184 with torch.tensor(grad_norm).item(), which yields a Python float whether grad_norm is a float or a tensor.
I am not sure if you want this to be fixed in logger though.
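A minimal sketch of the workaround described above, not the actual rlpyt code: the helper name `grad_norm_as_float`, the toy model, and the `grad_norms` list are hypothetical stand-ins for whatever the algorithm appends to its logging info.

```python
import torch


def grad_norm_as_float(grad_norm):
    """Return grad_norm as a Python float (hypothetical helper, not part of rlpyt)."""
    # torch.tensor() accepts both a Python float (older PyTorch) and a
    # 0-dim tensor (PyTorch >= 1.4). On a tensor input, recent versions emit
    # a UserWarning recommending clone().detach(); the result is still correct.
    return torch.tensor(grad_norm).item()


# Toy model and loss, only to obtain a gradient norm.
model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

grad_norms = []  # stand-in for the list of values handed to the logger
grad_norms.append(grad_norm_as_float(grad_norm))  # tensor in PyTorch >= 1.4
grad_norms.append(grad_norm_as_float(0.5))        # plain float also works
```

With this conversion at the append site, the logger only ever sees plain Python floats, so it does not need to know which PyTorch version produced them.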
Oh, good idea on that bit of backward compatibility! Thanks for posting. At some point I will probably just move everything forward to 1.4 or 1.5, unless there is some reason not to?