Open seigot opened 2 years ago
Macでは問題なく動くので、CUDAの有無が影響していそうである。
以下の差分のように、モデルと入力テンソルをCUDAへ転送する処理(`.cuda()` の呼び出し)をコメントアウトすると、問題は起こらなくなる。
diff --git a/game_manager/machine_learning/block_controller_train.py b/game_manager/machine_learning/block_controller_train.py
index 27b4786..2b32ffe 100644
--- a/game_manager/machine_learning/block_controller_train.py
+++ b/game_manager/machine_learning/block_controller_train.py
@@ -101,8 +101,8 @@ class Block_Controller(object):
self.model = torch.load(self.load_weight)
self.model.eval()
- if torch.cuda.is_available():
- self.model.cuda()
+# if torch.cuda.is_available():
+# self.model.cuda()
#=====Set hyper parameter=====
self.batch_size = cfg.train.batch_size
@@ -492,8 +492,8 @@ class Block_Controller(object):
next_actions, next_states = zip(*next_steps.items())
next_states = torch.stack(next_states)
- if torch.cuda.is_available():
- next_states = next_states.cuda()
+# if torch.cuda.is_available():
+# next_states = next_states.cuda()
self.model.train()
with torch.no_grad():
@@ -516,8 +516,8 @@ class Block_Controller(object):
next2_steps =self.get_next_func(next_backboard,next_piece_id,next_shape_class)
next2_actions, next2_states = zip(*next2_steps.items())
next2_states = torch.stack(next2_states)
- if torch.cuda.is_available():
- next2_states = next2_states.cuda()
+# if torch.cuda.is_available():
+# next2_states = next2_states.cuda()
self.model.train()
with torch.no_grad():
next_predictions = self.model(next2_states)[:, 0]
@@ -530,8 +530,8 @@ class Block_Controller(object):
next2_steps =self.get_next_func(next_backboard,next_piece_id,next_shape_class)
next2_actions, next2_states = zip(*next2_steps.items())
next2_states = torch.stack(next2_states)
- if torch.cuda.is_available():
- next2_states = next2_states.cuda()
+# if torch.cuda.is_available():
+# next2_states = next2_states.cuda()
self.target_model.train()
with torch.no_grad():
next_predictions = self.target_model(next2_states)[:, 0]
@@ -544,8 +544,8 @@ class Block_Controller(object):
next2_steps =self.get_next_func(next_backboard,next_piece_id,next_shape_class)
next2_actions, next2_states = zip(*next2_steps.items())
next2_states = torch.stack(next2_states)
- if torch.cuda.is_available():
- next2_states = next2_states.cuda()
+# if torch.cuda.is_available():
+# next2_states = next2_states.cuda()
self.model.train()
with torch.no_grad():
next_predictions = self.model(next2_states)[:, 0]
Hey, I'm a bit lost here! Not sure which file I should be fixing. Could you give me a bit more to go on? Maybe add some details to the issue or drop a comment with some extra hints? Thanks!
Have feedback or need help? Feel free to email info@gitauto.ai.