ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] DQNTrainer with TORCH framework fails when attempting to perform operations between tensors of different dtypes #9989

Closed: scotfang closed this issue 1 year ago

scotfang commented 4 years ago

What is the problem?

DQNTrainer with the TORCH framework fails when performing operations between tensors of different dtypes. I believe there are multiple places where tensors should be cast to float, but those casts are missing from the source code.

Ray version and other system information (Python version, TensorFlow version, OS):

- ray: 0.8.6
- python: 3.6.10
- pytorch: 1.2.0
- OS: Ubuntu 16.04.6 LTS

Reproduction (REQUIRED)

Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):

1) ray/rllib/examples/serving$ python cartpole_server.py --framework=torch
2) ray/rllib/examples/serving$ python cartpole_client.py
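
For a self-contained alternative that doesn't need the client/server pair, a hypothetical single-process script along these lines should hit the same error (API names assumed from the Ray 0.8.x era, not taken from the report):

```python
# Hypothetical standalone repro, assuming Ray 0.8.x RLlib APIs.
# Depending on the exact Ray version, the torch switch is either
# {"framework": "torch"} or the older {"use_pytorch": True}.
import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init()
trainer = DQNTrainer(env="CartPole-v0", config={"framework": "torch"})
print(trainer.train())  # should raise the same dtype RuntimeError on torch 1.2
```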

Error message below:

Traceback (most recent call last):
  File "cartpole_server.py", line 83, in <module>
    print(pretty_print(trainer.train()))
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 500, in train
    raise e
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 486, in train
    result = Trainable.train(self)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/tune/trainable.py", line 261, in train
    result = self._train()
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 132, in _train
    return self._train_exec_impl()
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 170, in _train_exec_impl
    res = next(self.train_exec_impl)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 731, in __next__
    return next(self.built_iterator)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 744, in apply_foreach
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 814, in apply_filter
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 814, in apply_filter
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 744, in apply_foreach
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 814, in apply_filter
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 1047, in build_union
    item = next(it)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 731, in __next__
    return next(self.built_iterator)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 744, in apply_foreach
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 744, in apply_foreach
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 744, in apply_foreach
    for item in it:
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/util/iter.py", line 752, in apply_foreach
    result = fn(item)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/execution/train_ops.py", line 69, in __call__
    info = self.workers.local_worker().learn_on_batch(batch)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 713, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 233, in learn_on_batch
    self._loss(self, self.model, self.dist_class, train_batch))
  File "/home/scotfang/envs/dest-dialog-py3_6/lib/python3.6/site-packages/ray/rllib/agents/dqn/dqn_torch_policy.py", line 173, in build_q_losses
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)
RuntimeError: expected device cuda:0 and dtype Float but got device cuda:0 and dtype Long

I made some source code changes that stopped the error from occurring during training, but I'm not sure if the changes make sense:

diff ray/rllib/agents/dqn/dqn_torch_policy.py ray/rllib/agents/dqn/dqn_torch_policy.py.patched

<         q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked
---
>         q_t_selected_target = rewards.float() + gamma**n_step * q_tp1_best_masked.float()
172c172
<     q_t_selected = torch.sum(q_t * one_hot_selection, 1)
---
>     q_t_selected = torch.sum(q_t * one_hot_selection.float(), 1)
185c185
<         q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
---
>         q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection.float(), 1)
ericl commented 4 years ago

Maybe upgrade PyTorch? It seems to run for me on 1.4.0.