alirezamika / evostra

A fast Evolution Strategy implementation in Python
MIT License

Add Early Stopping #11

Open vrishank97 opened 4 years ago

KirkSuD commented 2 years ago

I wrote some Keras-style callbacks for this a while back. They seemed to work for me, but they're probably not written well and I'm not sure whether there are bugs. I'm leaving them here in case they save someone some time. Feel free to edit and use them.

I'm not an active user of evostra; if I were, I'd want callbacks to be part of the package. (Or evostra becoming part of Keras? I don't know.)

import os
import csv
import pickle
import multiprocessing

from tqdm import tqdm
from evostra import EvolutionStrategy

class EvolutionStrategyWithCallbacks(EvolutionStrategy):
    def __init__(self, *args, val_reward_func=None, callbacks=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional validation reward function, called with the current
        # weights once per iteration (analogous to validation in Keras).
        self.val_reward_func = val_reward_func
        # Avoid the mutable-default-argument pitfall for the callback list.
        self.callbacks = callbacks if callbacks is not None else []

    def run(self, iterations, initial_iteration=0):
        # Same loop as EvolutionStrategy.run, but records a history dict
        # and invokes each callback after every iteration. A callback can
        # stop training early by setting the `training` flag to False.
        pool = multiprocessing.Pool(self.num_threads) if self.num_threads > 1 else None
        iteration = initial_iteration
        self.iterations = iterations
        self.training = True
        history = dict(reward=[], lr=[])
        if self.val_reward_func:
            history["val_reward"] = []
        while iteration < iterations and self.training:
            iteration += 1
            population = self._get_population()
            rewards = self._get_rewards(pool, population)

            self._update_weights(rewards, population)

            # Re-evaluate the updated weights so the logged reward reflects
            # the current model (one extra get_reward call per iteration).
            current_iter_reward = self.get_reward(self.weights)
            history["reward"].append(current_iter_reward)
            history["lr"].append(self.learning_rate)
            if self.val_reward_func:
                current_iter_val_reward = self.val_reward_func(self.weights)
                history["val_reward"].append(current_iter_val_reward)
            else:
                current_iter_val_reward = None
            for cb in self.callbacks:
                cb(self, iteration, current_iter_reward, current_iter_val_reward)
        if pool is not None:
            pool.close()
            pool.join()
        return history

class ProgbarLogger:
    """Shows a tqdm progress bar with the latest reward (and val_reward)."""

    def __init__(self, print_new_line=True):
        self.print_new_line = print_new_line
        self.pbar = None

    def __call__(self, es, it, reward, val_reward):
        if self.pbar is None:
            self.pbar = tqdm(total=es.iterations)
        desc = f"Iter {it}/{es.iterations} reward: {reward:.6f}"
        if val_reward is not None:
            desc += f" val_reward: {val_reward:.6f}"
        self.pbar.set_description_str(desc)
        self.pbar.update()
        if it >= es.iterations:
            # Close the bar once the final iteration has been reached.
            self.pbar.close()
        if self.print_new_line:
            print()

class EarlyStopping:
    """Stops training when the monitored reward has not improved for
    `patience` consecutive iterations, as in the Keras callback."""

    def __init__(self, monitor_val=True, patience=0, verbose=0):
        self.monitor_val = monitor_val
        self.patience = patience
        self.verbose = verbose
        self.best_it, self.best_reward = 0, float("-inf")

    def __call__(self, es, it, reward, val_reward):
        if self.monitor_val:
            if val_reward is None:
                raise ValueError("val_reward is None. Did you pass val_reward_func?")
            r = val_reward
        else:
            r = reward
        if r > self.best_reward:
            self.best_it, self.best_reward = it, r
        elif it - self.best_it >= self.patience:
            # No improvement for at least `patience` iterations: stop the
            # run by clearing the flag checked at the top of the loop. The
            # elif matters; checking unconditionally would stop immediately
            # with patience=0 even while the reward was still improving.
            if self.verbose:
                reward_type = "val_reward" if self.monitor_val else "reward"
                print(
                    f"EarlyStopping: iter {it}, {reward_type} {r} has not "
                    f"improved since iter {self.best_it} (best: {self.best_reward})."
                )
            es.training = False

class ModelCheckpoint:
    """Pickles the ES weights and hyperparameters every iteration, or only
    on improvement when save_best_only=True. `filepath` may contain the
    format fields {iter}, {reward} and {val_reward}."""

    def __init__(self, filepath, monitor_val=True, verbose=0, save_best_only=False):
        self.filepath = filepath
        self.monitor_val = monitor_val
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.best_reward = float("-inf")

    def __call__(self, es, it, reward, val_reward):
        if self.save_best_only:
            if self.monitor_val:
                if val_reward is None:
                    raise ValueError(
                        "val_reward is None. Did you pass val_reward_func?"
                    )
                r = val_reward
            else:
                r = reward
            if r > self.best_reward:
                if self.verbose:
                    reward_type = "val_reward" if self.monitor_val else "reward"
                    print(
                        f"{reward_type} increased from {self.best_reward} to {r}.",
                        end=" ",
                    )
                self.best_reward = r
            else:
                return
        format_dict = {"iter": it, "reward": reward}
        if val_reward is not None:
            format_dict["val_reward"] = val_reward
        fpath = self.filepath.format(**format_dict)
        if self.verbose:
            print("Saving ES to", fpath)
        # Create the target directory if there is one; plain filenames with
        # no directory component are saved to the working directory.
        fdir = os.path.dirname(fpath)
        if fdir:
            os.makedirs(fdir, exist_ok=True)
        with open(fpath, "wb") as file:
            # Save enough state to rebuild the EvolutionStrategy later.
            all_data = {
                "weights": es.weights,
                "population_size": es.POPULATION_SIZE,
                "sigma": es.SIGMA,
                "learning_rate": es.learning_rate,
                "decay": es.decay,
                "num_threads": es.num_threads,
            }
            pickle.dump(all_data, file)
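
# Untested helper sketching one way to load a ModelCheckpoint pickle back.
# The reward function has to be passed in again, since only the data dict
# saved above gets pickled, not any functions.
def load_checkpoint(fpath, get_reward, **kwargs):
    with open(fpath, "rb") as file:
        data = pickle.load(file)
    return EvolutionStrategyWithCallbacks(
        data["weights"], get_reward,
        population_size=data["population_size"], sigma=data["sigma"],
        learning_rate=data["learning_rate"], decay=data["decay"],
        num_threads=data["num_threads"], **kwargs,
    )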

class CSVLogger:
    """Appends one row per iteration to a CSV file, writing a header first."""

    def __init__(self, filename, separator=","):
        self.filename = filename
        self.separator = separator

    def __call__(self, es, it, reward, val_reward):
        # Write the header once, when the file does not exist yet.
        if not os.path.exists(self.filename):
            header = ["iter", "reward", "lr"]
            if val_reward is not None:
                header.append("val_reward")
            with open(self.filename, "w", newline="", encoding="utf-8-sig") as file:
                writer = csv.writer(file, delimiter=self.separator)
                writer.writerow(header)
        data = [it, reward, es.learning_rate]
        if val_reward is not None:
            data.append(val_reward)
        with open(self.filename, "a", newline="", encoding="utf-8-sig") as file:
            writer = csv.writer(file, delimiter=self.separator)
            writer.writerow(data)
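
For reference, here's roughly how everything could be wired together. An untested sketch: the toy weights and reward functions are placeholders for a real model (evostra's README uses model.get_weights() and a user-supplied get_reward), and the file paths and hyperparameters are just made-up examples.

import numpy as np

weights = [np.zeros((4, 2))]  # stand-in for model.get_weights()

def get_reward(w):
    # Toy objective: drive all weights toward zero.
    return -float(np.sum(w[0] ** 2))

def get_val_reward(w):
    # Stand-in validation reward; normally evaluated on held-out episodes.
    return get_reward(w)

es = EvolutionStrategyWithCallbacks(
    weights, get_reward,
    population_size=20, sigma=0.1, learning_rate=0.03, decay=0.995,
    val_reward_func=get_val_reward,
    callbacks=[
        ProgbarLogger(),
        EarlyStopping(patience=10, verbose=1),
        ModelCheckpoint("checkpoints/es_{iter}.pkl", save_best_only=True, verbose=1),
        CSVLogger("training_log.csv"),
    ],
)
history = es.run(100)

Callbacks run in list order each iteration, and the training flag is only checked at the top of the loop, so every callback still sees the final iteration even when EarlyStopping fires first.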