weipu-zhang / STORM


RuntimeError: Current CUDA Device does not support bfloat16. #4

Closed: Liuxueyi closed this issue 4 months ago

Liuxueyi commented 4 months ago

When I run the code on a P100, the error below occurs. How can I fix it? Looking forward to your reply. Here is the full output:

Namespace(n='Pong-life_done-wm_2L512D8H-100k-seed1', seed=1, config_path='config_files/STORM.yaml', env_name='ALE/Pong-v5', trajectory_path='D_TRAJ/Pong.pkl')
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Current env: ALE/Pong-v5
0%| | 0/102000 [00:00<?, ?it/s]
Saving model at total steps 0
1%|█▎ | 1024/102000 [00:01<02:26, 688.80it/s]
Traceback (most recent call last):
  File "/home/lxy/mamba_dreamer/STORM-main/train.py", line 262, in <module>
    joint_train_world_model_agent(
  File "/home/lxy/mamba_dreamer/STORM-main/train.py", line 142, in joint_train_world_model_agent
    train_world_model_step(
  File "/home/lxy/mamba_dreamer/STORM-main/train.py", line 49, in train_world_model_step
    world_model.update(obs, action, reward, termination, logger=logger)
  File "/home/lxy/mamba_dreamer/STORM-main/sub_models/world_models.py", line 382, in update
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp):
  File "/home/lxy/anaconda3/envs/storm/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 234, in __init__
    raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
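The traceback shows that `world_models.py` hard-codes `dtype=torch.bfloat16` in its `torch.autocast` call. The error message itself suggests float16 instead. Below is a minimal sketch that picks the autocast dtype from the GPU's CUDA compute capability instead of hard-coding it. It assumes that native bfloat16 support starts at compute capability 8.0 (Ampere) and that a P100 reports 6.0. Both function names are hypothetical and are not part of STORM.

```python
def amp_dtype_name(major: int, minor: int) -> str:
    """Return the autocast dtype name for a given CUDA compute capability.

    Assumption (hypothetical helper): bfloat16 is treated as supported only
    from compute capability 8.0 (Ampere) upward; older cards such as the
    P100 (6.0) fall back to float16.
    """
    return "bfloat16" if (major, minor) >= (8, 0) else "float16"


def pick_autocast_dtype():
    """Return torch.bfloat16 on bf16-capable GPUs, otherwise torch.float16."""
    # Imported lazily so amp_dtype_name() can be used without PyTorch installed.
    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        if amp_dtype_name(major, minor) == "bfloat16":
            return torch.bfloat16
    # No CUDA device, or a pre-Ampere device: use float16.
    return torch.float16
```

One could then pass `dtype=pick_autocast_dtype()` to the `torch.autocast(device_type='cuda', ...)` call at `world_models.py` line 382. Note that float16 has a much narrower range than bfloat16, so float16 mixed-precision training usually also needs gradient scaling (for example `torch.cuda.amp.GradScaler`) to avoid underflowing gradients.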

Liuxueyi commented 4 months ago

Sorry, I had overlooked the instructions in readme.md.