patrickloeber / snake-ai-pytorch

MIT License
How to load the saved model (model.pth)? #4

Closed begyy closed 2 years ago

begyy commented 2 years ago

When I start the script, it trains from scratch again. Do I need to load the model I already have?

begyy commented 2 years ago

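def get_action(self, state):
    # Inference only: no random exploration, always play the model's best move.
    # Needs `import torch` and the project's Linear_QNet class in scope.
    final_move = [0, 0, 0]
    # Rebuild the network with the sizes used in training, then restore the weights.
    # Note: this reloads the file on every call; loading once at startup is cheaper.
    model = Linear_QNet(11, 256, 3)
    model.load_state_dict(torch.load('model/model.pth'))
    model.eval()
    state0 = torch.tensor(state, dtype=torch.float)
    with torch.no_grad():  # no gradients needed when only predicting
        prediction = model(state0)
    move = torch.argmax(prediction).item()
    final_move[move] = 1

    return final_move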
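
The snippet above reads model.pth from disk on every move. A possible variation is to load the weights once and reuse the same model object afterwards. The sketch below is only an illustration: the `Linear_QNet` defined here is a stand-in whose layer layout is assumed (it has to match whatever produced the checkpoint), and `load_trained_model` and `greedy_move` are hypothetical helper names, not part of the repository.

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F


class Linear_QNet(nn.Module):
    # Stand-in network. The two-layer layout is an assumption and must
    # match the architecture that was used to save model.pth.
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


def load_trained_model(path="model/model.pth"):
    """Hypothetical helper: build the network and restore weights if the file exists."""
    model = Linear_QNet(11, 256, 3)
    if os.path.exists(path):
        # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(path, map_location="cpu"))
        model.eval()  # inference mode
    return model


def greedy_move(model, state):
    """Hypothetical helper: return a one-hot move for the highest predicted Q-value."""
    state0 = torch.tensor(state, dtype=torch.float)
    with torch.no_grad():  # no gradients needed for prediction
        prediction = model(state0)
    final_move = [0, 0, 0]
    final_move[torch.argmax(prediction).item()] = 1
    return final_move
```

With something like this, the agent could call `load_trained_model()` once (for example in its `__init__`) and then call `greedy_move(self.model, state)` from `get_action`. If training should continue from the checkpoint instead of only playing, the epsilon-based random moves would have to stay in place and the model should be left in training mode.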