Martyn0324 / SerpentAI

Game Agent Framework. Helping you create AIs / Bots that learn to play any game you own!
http://serpent.ai
MIT License

SerpentAI KeyError: -1 (with self.current_action variable) #9

Open Martyn0324 opened 2 years ago

Martyn0324 commented 2 years ago

Expected result

Run RainbowDQN in TRAIN mode to see how the model behaves.

Encountered result

A KeyError: -1 is raised for self.current_action in the file rainbow_dqn_agent.py.

Steps to reproduce

  1. Open your Game Agent Plugin code
  2. In setup_play, insert self.game_agent.set_mode(1). The default mode is 0, which corresponds to OBSERVE mode in Rainbow; 1 is TRAIN.
  3. Run your code

Apparently, the self.current_action variable, which is initialized to -1 when RainbowDQN starts, is never reassigned while an action is being chosen in the function generate_actions(self, state, **kwargs).

I don't know why this happens. Maybe instead of self.game_agent.set_mode(1) I should use self.game_agent.set_mode(RainbowDQNAgentModes.TRAIN).
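A likely explanation, offered as a hedged sketch rather than a confirmed diagnosis: RainbowDQNAgentModes subclasses enum.Enum, and plain Enum members never compare equal to integers. After set_mode(1), the check self.mode == RainbowDQNAgentModes.TRAIN is False, no branch reassigns self.current_action, and it keeps its initial -1. The FakeAgent class below is hypothetical (it models only the mode logic, not the real RainbowDQNAgent), and its actions dict is an assumed stand-in for whatever lookup actually raises the KeyError in rainbow_dqn_agent.py.

```python
import enum


# Mirrors the enum defined at the top of rainbow_dqn_agent.py.
class RainbowDQNAgentModes(enum.Enum):
    OBSERVE = 0
    TRAIN = 1
    EVALUATE = 2


class FakeAgent:
    """Hypothetical stand-in for RainbowDQNAgent; only the mode checks are modeled."""

    def __init__(self):
        self.current_action = -1  # same initial value as the real agent
        self.mode = RainbowDQNAgentModes.OBSERVE
        # Hypothetical action table; the real lookup that raises KeyError may differ.
        self.actions = {0: "LEFT", 1: "RIGHT"}

    def set_mode(self, mode):
        self.mode = mode

    def generate_actions(self):
        if self.mode == RainbowDQNAgentModes.OBSERVE:
            self.current_action = 0
        elif self.mode == RainbowDQNAgentModes.TRAIN:
            self.current_action = 1
        # With an int mode, neither branch matches, so current_action stays -1.
        return self.actions[self.current_action]


print(RainbowDQNAgentModes.TRAIN == 1)  # False: Enum members are not ints

agent = FakeAgent()
agent.set_mode(1)
try:
    agent.generate_actions()
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: -1

agent.set_mode(RainbowDQNAgentModes.TRAIN)
print(agent.generate_actions())  # RIGHT
```

If this model is right, passing the enum member to set_mode is enough on its own; the int 1 is simply never equal to RainbowDQNAgentModes.TRAIN.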

This reassignment problem doesn't happen in OBSERVE mode, which is set by default, as can be seen at line 160:


        self.current_action = -1

        self.observe_mode = "RANDOM"
        self.set_mode(RainbowDQNAgentModes.OBSERVE)

        self.model = self.agent_kwargs["model"]

        if os.path.isfile(self.model):
            self.observe_mode = "MODEL"
            self.restore_model()

This problem can be solved by going to the generate_actions() function and replacing:

        elif self.mode == RainbowDQNAgentModes.TRAIN:
            self.agent.reset_noise()
            self.current_action = self.agent.act(self.current_state)

with

        elif self.mode == 1:
            self.agent.reset_noise()
            self.current_action = self.agent.act(self.current_state)

At the beginning of the RainbowDQNAgent code, there is this class:

    class RainbowDQNAgentModes(enum.Enum):
        OBSERVE = 0
        TRAIN = 1
        EVALUATE = 2

Perhaps we should consider eliminating it and simply replacing RainbowDQNAgentModes.OBSERVE/TRAIN/EVALUATE with 0/1/2.
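A possible middle ground, sketched here as an untested suggestion rather than a verified SerpentAI change: keep the names but derive the class from enum.IntEnum instead of enum.Enum. IntEnum members are ints, so comparisons against plain 0/1/2 succeed in both directions, and existing code that compares against RainbowDQNAgentModes.TRAIN keeps working.

```python
import enum


# Same members as the original class, but IntEnum instead of Enum.
class RainbowDQNAgentModes(enum.IntEnum):
    OBSERVE = 0
    TRAIN = 1
    EVALUATE = 2


mode = 1  # what set_mode(1) would store
print(mode == RainbowDQNAgentModes.TRAIN)  # True
print(RainbowDQNAgentModes.TRAIN == 1)  # True
print(isinstance(RainbowDQNAgentModes.EVALUATE, int))  # True
```

The trade-off is that an IntEnum also compares equal to unrelated integers, so a typo such as set_mode(7) would still go unnoticed until a lookup fails.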