Trying to follow the example for the Discretizer to limit the number of actions for Sonic 1. Fairly basic code and error is below:
import retro
import gym
import numpy as np
class Discretizer(gym.ActionWrapper):
"""
Wrap a gym environment and make it use discrete actions.
Args:
combos: ordered list of lists of valid button combinations
"""
def __init__(self, env, combos):
super().__init__(env)
assert isinstance(env.action_space, gym.spaces.MultiBinary)
buttons = env.unwrapped.buttons
self._decode_discrete_action = []
for combo in combos:
arr = np.array([False] * env.action_space.n)
for button in combo:
arr[buttons.index(button)] = True
self._decode_discrete_action.append(arr)
self.action_space = gym.spaces.Discrete(len(self._decode_discrete_action))
def action(self, act):
return self._decode_discrete_action[act].copy()
class SonicDiscretizer(Discretizer):
"""
Use Sonic-specific discrete actions
based on https://github.com/openai/retro-baselines/blob/master/agents/sonic_util.py
"""
def __init__(self, env):
super().__init__(env=env, combos=[['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'], ['DOWN', 'B'], ['B'], ['LEFT', 'B'], ['RIGHT', 'B']])
env = retro.make('SonicTheHedgehog-Genesis', 'GreenHillZone.Act1')
env = SonicDiscretizer(env)
env.reset()
done = False
while not done:
env.render()
#action = env.action_space.sample()
action = [0,0,1,0,0,0,0,1,1]
ob, rew, done, info = env.step(action)
print("Action ", action, "Reward ", rew)
andreas@andreas-virtual-machine:~/Desktop$ python3 testsonic.py
Traceback (most recent call last):
File "testsonic.py", line 53, in <module>
ob, rew, done, info = env.step(action)
File "/home/andreas/.local/lib/python3.8/site-packages/gym/core.py", line 292, in step
return self.env.step(self.action(action))
File "testsonic.py", line 27, in action
return self._decode_discrete_action[act].copy()
TypeError: list indices must be integers or slices, not list
Trying to follow the example for the Discretizer to limit the number of actions for Sonic 1. Fairly basic code and error is below: