araffin / rl-baselines-zoo

A collection of 100+ pre-trained RL agents using Stable Baselines, training and hyperparameter optimization included.
https://stable-baselines.readthedocs.io/

[feature request] Execution Drivers and Standardized Agent Interface #34

Closed: jmribeiro closed this issue 5 years ago

jmribeiro commented 5 years ago

Currently, the way to train an agent is to 1) instantiate the environment, 2) instantiate the agent, passing the environment to its constructor, and 3) call the learn method.

Some agent frameworks have started implementing execution drivers, i.e., objects responsible for driving the interaction between the agent and the environment.

Would love to see such a feature in stable-baselines, given that it would greatly simplify the pipeline for testing a new agent and comparing it with existing ones.

Current code:

```python
env = gym.make("CartPole-v0")
agent = DQN("MlpPolicy", env, ...)
agent.learn(...)
```

What if there were driver objects such that the execution would go something like this:

```python
metrics = [TotalTimesteps(), AverageTimestepReward(), AverageEpisodeReward(), ...]
driver = BottleneckDriver(agent, env, max_timesteps=10000, max_episodes=200, observers=metrics)
driver.run()
for metric in metrics:
    print(f"{metric.name}: {metric.result()}")
```
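For concreteness, here is a minimal sketch of what one of these metric objects might look like, assuming (as in the base class below) that observers are callables invoked with a Timestep after every step. AverageEpisodeReward and its fields are hypothetical, not existing stable-baselines API:

```python
class AverageEpisodeReward:
    """Hypothetical metric/observer: accumulates the reward of the current
    episode and reports the mean reward over completed episodes."""

    name = "AverageEpisodeReward"

    def __init__(self):
        self._current = 0.0
        self._episode_rewards = []

    def __call__(self, timestep):
        # Called by the driver after every step
        self._current += timestep.reward
        if timestep.is_terminal:
            self._episode_rewards.append(self._current)
            self._current = 0.0

    def result(self):
        if not self._episode_rewards:
            return float("nan")
        return sum(self._episode_rewards) / len(self._episode_rewards)
```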

Example


```python
import math


class BottleneckDriver(BaseDriver):
    """
    Runs until one of the conditions is met: max_timesteps or max_episodes.
    """

    def __init__(self, agent, environment, max_timesteps=math.inf, max_episodes=math.inf, observers=None):
        super(BottleneckDriver, self).__init__(agent, environment, observers)
        self._timesteps = max_timesteps
        self._episodes = max_episodes

    def run(self):
        self._environment.reset()
        done = False
        while not done:
            timestep = self.step()
            if timestep.is_terminal:
                # Start a new episode when the previous one ends
                self._environment.reset()
            done = self.total_episodes >= self._episodes or self.total_steps >= self._timesteps
```
And a base class:

```python
from abc import ABC, abstractmethod
from collections import namedtuple

Timestep = namedtuple("Timestep", "t state action reward next_state is_terminal info")


class BaseDriver(ABC):

    def __init__(self, agent, environment, observers):
        """
        :param agent: The agent that interacts with the environment
        :param environment: The environment
        :param observers: Callables invoked with each Timestep
        """
        self._agent = agent
        self._environment = environment
        self._observers = observers or []

        self._total_steps = 0
        self._total_episodes = 0

    @property
    def total_steps(self):
        return self._total_steps

    @property
    def total_episodes(self):
        return self._total_episodes

    @abstractmethod
    def run(self):
        raise NotImplementedError()

    def step(self):
        state = self._environment.state
        action = self._agent.action(state)
        timestep = self._environment.step(action)
        for observer in self._observers:
            observer(timestep)
        # Update the agent once per step, outside the observer loop
        self._agent.reinforcement(timestep)
        self._total_steps += 1
        if timestep.is_terminal:
            self._total_episodes += 1
        return timestep

    def episode(self):
        self._environment.reset()
        is_terminal = False
        trajectory = [self._environment.state]
        while not is_terminal:
            timestep = self.step()
            trajectory.append(timestep)
            is_terminal = timestep.is_terminal
        return trajectory
```
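Since the driver assumes an environment that exposes a `state` attribute and a `step()` returning a full `Timestep`, a plain gym environment would need a thin adapter. Below is a minimal sketch under that assumption, using the classic gym step API that returns `(obs, reward, done, info)`; `TimestepEnv` is a hypothetical name, not part of gym or stable-baselines:

```python
import gym


class TimestepEnv:
    """Hypothetical adapter exposing the interface the driver expects."""

    def __init__(self, env):
        self._env = env
        self.state = None
        self._t = 0

    def reset(self):
        self.state = self._env.reset()
        self._t = 0
        return self.state

    def step(self, action):
        next_state, reward, done, info = self._env.step(action)
        timestep = Timestep(self._t, self.state, action, reward, next_state, done, info)
        self.state = next_state
        self._t += 1
        return timestep


# Hypothetical usage with the proposed driver (agent and metrics defined as above):
env = TimestepEnv(gym.make("CartPole-v0"))
driver = BottleneckDriver(agent, env, max_episodes=200, observers=metrics)
driver.run()
```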

Would love to contribute such features. Let me know what you think.

araffin commented 5 years ago

Closing this issue in favor of this one: https://github.com/hill-a/stable-baselines/issues/381