hakuhodo-technologies / scope-rl

SCOPE-RL: A python library for offline reinforcement learning, off-policy evaluation, and selection
https://scope-rl.readthedocs.io/en/latest/
Apache License 2.0

Runtime error when creating a tensor with device='cuda' in ContinuousDiceStateActionWightValueLearning #31

Open xiaobaijh opened 3 months ago

xiaobaijh commented 3 months ago

In class ContinuousDiceStateActionWightValueLearning, line 400, we use state = torch.FloatTensor(state, device=self.device) to create a tensor. If self.device is 'cuda', this throws 'RuntimeError: legacy constructor expects device type: cpu but device type: cuda was passed'. It should be state = torch.tensor(state, device=self.device). The same applies to the other tensors created in this class.
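A minimal sketch of the suggested fix, not the library's actual code: the helper name to_float_tensor and the random array standing in for state are assumptions made for illustration. One detail to keep in mind is that torch.tensor infers its dtype from the input, so a float64 NumPy array would become a float64 tensor; passing dtype=torch.float32 keeps the dtype that torch.FloatTensor would have produced.

```python
import numpy as np
import torch


def to_float_tensor(array, device):
    """Hypothetical helper: copy `array` to `device` as a float32 tensor."""
    # torch.tensor accepts any device, unlike the legacy torch.FloatTensor
    # constructor, which only accepts CPU. dtype=torch.float32 matches the
    # dtype that torch.FloatTensor would have produced.
    return torch.tensor(array, dtype=torch.float32, device=device)


# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed stand-in for a batch of states (NumPy defaults to float64).
state = np.random.rand(4, 3)

state_tensor = to_float_tensor(state, device)
print(state_tensor.dtype, state_tensor.device)
```

If the input is already a tensor, torch.as_tensor(state, dtype=torch.float32, device=device) may be preferable, since it avoids a copy when the dtype and device already match.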