astooke / rlpyt

Reinforcement Learning in PyTorch
MIT License
Exception in logger with PyTorch >= 1.4 #152

Open ankeshanand opened 4 years ago

ankeshanand commented 4 years ago

In PyTorch >= 1.4, grad_norm is a torch tensor rather than a float (changed in https://github.com/pytorch/pytorch/pull/32020), so the logger throws an exception here, because values is now a list of PyTorch tensors: https://github.com/astooke/rlpyt/blob/a54cb5b1ee7b68d757aa0baa6a2786548419e366/rlpyt/utils/logging/logger.py#L457

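To maintain backwards compatibility, an easy fix is to change the append calls, such as https://github.com/astooke/rlpyt/blob/35af87255465b3644747294f7fd1ff6045dab910/rlpyt/algos/dqn/dqn.py#L184, so that they append torch.tensor(grad_norm).item() instead of grad_norm. That yields a plain Python float whether grad_norm arrives as a float (older PyTorch) or as a tensor (1.4 and later).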

I am not sure whether you would prefer to fix this in the logger instead, though.
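
For illustration, here is a minimal sketch of the same normalization done without importing torch, so it runs anywhere. The helper name to_python_float and the FakeScalarTensor stand-in are hypothetical and not part of rlpyt or PyTorch. The only assumption about the real objects is that a one-element tensor exposes an .item() method that returns a Python number, as torch.Tensor does. The helper could be applied either at the append call or inside the logger before statistics are computed.

```python
import statistics


def to_python_float(value):
    """Return a plain float from a float or a one-element tensor-like object.

    Hypothetical helper: it duck-types on `.item()`, which torch.Tensor
    provides, and falls back to float() for ordinary numbers.
    """
    item = getattr(value, "item", None)
    if callable(item):
        return float(item())
    return float(value)


class FakeScalarTensor:
    """Hypothetical stand-in for a zero-dimensional tensor (no PyTorch needed)."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


# Mixed inputs: a float (older PyTorch) and a tensor-like value (PyTorch >= 1.4).
grad_norms = []
for raw in (0.5, FakeScalarTensor(1.5)):
    grad_norms.append(to_python_float(raw))

# Every entry is now a float, so ordinary statistics work on the list.
mean_grad_norm = statistics.mean(grad_norms)
```

Normalizing at the append call keeps the logger unchanged, while normalizing inside the logger would cover every logged quantity at once. Which is preferable depends on how the maintainers want to scope the fix.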

astooke commented 4 years ago

Oh, good idea on that bit of backward compatibility! Thanks for posting. At some point I will probably just move everything forward to 1.4 or 1.5, unless there is some reason not to?