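In class `ContinuousDiceStateActionWightValueLearning`, line 400 creates a tensor with `state = torch.FloatTensor(state, device=self.device)`. If `self.device` is `'cuda'`, this throws `RuntimeError: legacy constructor expects device type: cpu but device type: cuda was passed`, because the legacy `torch.FloatTensor` constructor only builds CPU tensors. It should be `state = torch.tensor(state, dtype=torch.float32, device=self.device)`. The explicit `dtype` keeps the float32 type that `FloatTensor` produced; without it, `torch.tensor` infers float64 from a float64 NumPy array. The same applies to the other tensors created in this class.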
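A minimal sketch of the suggested fix follows. It is not the class's actual code. The helper `to_float_tensor` and the sample `state` array are hypothetical and exist only for illustration. The sketch falls back to CPU when CUDA is unavailable, so it runs on any machine.

```python
import numpy as np
import torch


def to_float_tensor(state, device):
    """Hypothetical helper: build a float32 tensor directly on `device`.

    torch.FloatTensor(state, device="cuda") fails with
    "legacy constructor expects device type: cpu ..." because the legacy
    constructor only creates CPU tensors. torch.tensor accepts any device,
    and dtype=torch.float32 matches what FloatTensor would have produced.
    """
    return torch.tensor(state, dtype=torch.float32, device=device)


# Use CUDA if present, otherwise CPU (the fix works the same on both).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative float64 NumPy state, e.g. as returned by an environment.
state = np.array([0.1, 0.2, 0.3])

t = to_float_tensor(state, device)
print(t.dtype, t.device)

# Without an explicit dtype, torch.tensor keeps NumPy's float64.
print(torch.tensor(state).dtype)
```

Passing `dtype` explicitly matters if the rest of the class (network weights, other inputs) is float32. Mixing float64 inputs with float32 layers would otherwise raise a dtype mismatch error.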