nnaisense / evotorch

Advanced evolutionary computation library built directly on top of PyTorch, created at NNAISENSE.
https://evotorch.ai
Apache License 2.0
997 stars 62 forks source link

Improve MAPElites performance using torch_scatter #93

Open JakeF-Bitweave opened 10 months ago

JakeF-Bitweave commented 10 months ago

Would you be interested in contributions that rework MAPElites to use torch_scatter instead of the vmapped `extended_population × feature_grid` operation?

The general gist is:

JakeF-Bitweave commented 10 months ago

Runnable comparison:

import math
import time
from typing import NamedTuple, List, Iterable

import torch
from torch_scatter import scatter_max, scatter_min
from evotorch import Problem
from evotorch.algorithms import MAPElites, SearchAlgorithm
from evotorch.algorithms.ga import ExtendedPopulationMixin
from evotorch.algorithms.searchalgorithm import SinglePopulationAlgorithmMixin
from evotorch.operators import GaussianMutation, SimulatedBinaryCrossOver

class FeatureGrid(NamedTuple):
    """
    Axis-aligned feature grid consumed by ``MAPElitesScatter``.

    The three fields are parallel lists holding one entry per feature
    dimension, so ``zip(*grid)`` yields ``(lower, upper, n_bins)`` for each
    dimension in turn. Field order is significant: it defines both the
    positional constructor and tuple unpacking.
    """

    # Lower edge of the range of each feature dimension.
    lower_bounds: List[float]
    # Upper edge of the range of each feature dimension.
    upper_bounds: List[float]
    # Number of bins each feature dimension is divided into.
    bins: List[int]

class MAPElitesScatter(MAPElites):
    """
    MAP-Elites variant that selects cell elites with ``torch_scatter`` reductions.

    Each solution is not tested against every cell of the feature grid.
    Instead, its feature values are discretized into one flat cell index,
    and ``scatter_max`` / ``scatter_min`` (chosen by the objective sense)
    then picks the best solution of every cell in a single pass.

    After each step, the i-th member of the population holds the elite of
    the i-th cell, and ``filled[i]`` tells whether that cell had any member.
    A cell without a member holds a copy of the best solution of the
    extended population, not a zero-filled placeholder. Every population
    entry is therefore a real, evaluated solution. Use ``filled`` to tell
    the two cases apart.

    Cell numbering: the bin index of feature dimension 0 varies fastest, i.e.
    ``cell = b0 + bins[0] * (b1 + bins[1] * (b2 + ...))``.
    """

    def __init__(
        self,
        problem: Problem,
        *,
        operators: Iterable,
        feature_grid: FeatureGrid,
    ):
        """
        Args:
            problem: Single-objective numeric problem. Column 0 of its
                evaluation results is the fitness, and the following columns
                are the features, one per grid dimension.
            operators: Variation operators used to produce the children.
            feature_grid: Per-dimension lower bounds, upper bounds and bin
                counts. Features below the lower bound (or above the upper
                bound) fall into the first (or last) bin of their dimension.

        Raises:
            ValueError: If the grid lists differ in length, if an upper bound
                does not exceed its lower bound, or if a bin count is below 1.
        """
        problem.ensure_single_objective()
        problem.ensure_numeric()

        lower_bounds, upper_bounds, bins = feature_grid
        if not len(lower_bounds) == len(upper_bounds) == len(bins):
            raise ValueError(
                "feature_grid.lower_bounds, feature_grid.upper_bounds and feature_grid.bins"
                " must have the same length"
            )
        for lb, ub, n_bins in zip(lower_bounds, upper_bounds, bins):
            if not ub > lb:
                raise ValueError(f"Each upper bound must exceed its lower bound, got [{lb}, {ub}]")
            if n_bins < 1:
                raise ValueError(f"Each feature dimension needs at least one bin, got {n_bins}")

        SearchAlgorithm.__init__(self, problem)
        self._sense = self._problem.senses[0]
        self._feature_grid = feature_grid
        # One population slot per grid cell.
        self._popsize = math.prod(feature_grid.bins)
        self._population = problem.generate_batch(self._popsize)
        self._filled = torch.zeros(self._popsize, dtype=torch.bool, device=self._population.device)
        self._scatter_best = scatter_max if self._sense == "max" else scatter_min
        ExtendedPopulationMixin.__init__(
            self,
            re_evaluate=True,
            re_evaluate_parents_first=None,
            operators=operators,
            allow_empty_operators_list=False,
        )
        SinglePopulationAlgorithmMixin.__init__(self)

    def _cell_indices(self, feats: torch.Tensor) -> torch.Tensor:
        """
        Return the flat feature-grid cell index of every row of ``feats``.

        Args:
            feats: Feature values with one row per solution. Only the first
                ``len(feature_grid.bins)`` columns are used.

        Returns:
            A ``torch.long`` tensor with one cell index per row, each in
            ``[0, prod(bins))``. ``feats`` itself is not modified.
        """
        lower_bounds, upper_bounds, bins = self._feature_grid
        feats = feats[:, : len(bins)]
        device, dtype = feats.device, feats.dtype

        lower = torch.as_tensor(lower_bounds, dtype=dtype, device=device)
        upper = torch.as_tensor(upper_bounds, dtype=dtype, device=device)
        n_bins = torch.as_tensor(bins, dtype=torch.long, device=device)
        n_bins_f = n_bins.to(dtype)

        # Continuous bin coordinate per dimension. It is clamped into
        # [0, n_bins - 1] in floating point, so out-of-range features land in
        # the edge bins and extreme values cannot overflow the integer cast.
        scaled = (feats - lower) * (n_bins_f / (upper - lower))
        scaled = torch.minimum(scaled.clamp_min(0), n_bins_f - 1)
        # All values are non-negative here, so truncation equals floor.
        bin_index = scaled.long()

        # Dimension 0 varies fastest: strides are [1, bins[0], bins[0] * bins[1], ...].
        strides = torch.cumprod(torch.cat([n_bins.new_ones(1), n_bins[:-1]]), dim=0)
        return (bin_index * strides).sum(dim=-1)

    def _step(self):
        """Advance one generation: build the extended population, then keep the elite of each cell."""
        # Form an extended population from the parents and from the children
        extended_population = self._make_extended_population(split=False)
        extended_pop_size = extended_population.eval_shape[0]

        all_evals = extended_population.evals.as_subclass(torch.Tensor)
        all_values = extended_population.values.as_subclass(torch.Tensor)
        # Column 0 is the fitness, and the remaining columns are the features.
        # NOTE: these tensors may share storage with the extended population, so
        # they are only read below and never modified in place.
        all_fitnesses = all_evals[:, 0]
        feats = all_evals[:, 1:]

        cell_index = self._cell_indices(feats)

        # Best member of each cell. With dim_size, the result covers every cell
        # even when the highest-numbered cells are empty. torch_scatter marks an
        # empty cell with the out-of-range argument index extended_pop_size.
        _, argbest = self._scatter_best(all_fitnesses, cell_index, dim=0, dim_size=self._popsize)
        suitable = argbest < extended_pop_size

        # An empty cell receives a copy of the overall best solution rather than
        # zeros. Its values and evals then describe a real, evaluated solution,
        # and no fabricated fitness leaks into population statistics.
        if self._sense == "max":
            overall_best = torch.argmax(all_fitnesses)
        else:
            overall_best = torch.argmin(all_fitnesses)
        source = torch.where(suitable, argbest, overall_best)

        # Place the selected decision values and evaluation results into the current population.
        self._population.access_values(keep_evals=True)[:] = all_values[source]
        self._population.access_evals()[:] = all_evals[source]

        # filled[i] is True iff the i-th cell received a member of the extended population.
        self._filled[:] = suitable

def kursawe(x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the Kursawe test function, packed as fitness plus two features.

    The two Kursawe objectives are::

        f1(x) = sum_{i=1}^{n-1} -10 * exp(-0.2 * sqrt(x_i^2 + x_{i+1}^2))
        f2(x) = sum_{i=1}^{n}   |x_i|^0.8 + 5 * sin(x_i^3)

    Args:
        x: Decision values with the solution dimension last, e.g. ``(batch, n)``
            or a single ``(n,)`` vector. ``n >= 2`` is needed for ``f1`` to be
            non-trivial.

    Returns:
        Tensor of shape ``(..., 3)`` holding ``[f1 + f2, f1, f2]``. The first
        column serves as the fitness and the other two as the features.
    """
    # Consecutive pairs (x_i, x_{i+1}). Slicing with ``...`` works for any n and batch shape.
    f1 = torch.sum(
        -10 * torch.exp(
            -0.2 * torch.sqrt(x[..., :-1] ** 2.0 + x[..., 1:] ** 2.0)
        ),
        dim=-1,
    )
    f2 = torch.sum(
        (torch.abs(x) ** 0.8) + (5 * torch.sin(x ** 3)),
        dim=-1,
    )
    return torch.stack([f1 + f2, f1, f2], dim=-1)

if __name__ == "__main__":
    # Benchmark MAPElitesScatter against the stock MAPElites on the same problem
    # and the same feature grid.

    # Define the grid once so both implementations are guaranteed to use
    # identical cells.
    grid_lower_bounds = [-20, -14]
    grid_upper_bounds = [-10, 4]
    grid_num_bins = 50

    tensor_feature_grid = MAPElites.make_feature_grid(
        lower_bounds=grid_lower_bounds,
        upper_bounds=grid_upper_bounds,
        num_bins=grid_num_bins,
        dtype="float32",
    )
    scatter_feature_grid = FeatureGrid(
        grid_lower_bounds,
        grid_upper_bounds,
        [grid_num_bins] * len(grid_lower_bounds),
    )

    for clazz, feature_grid in [
        (MAPElitesScatter, scatter_feature_grid),
        (MAPElites, tensor_feature_grid),
    ]:
        # Fresh problem instance for each implementation.
        problem = Problem(
            "min",
            kursawe,
            solution_length=3,
            eval_data_length=2,
            bounds=(-5.0, 5.0),
            vectorized=True,
        )
        searcher = clazz(
            problem,
            feature_grid=feature_grid,
            operators=[
                SimulatedBinaryCrossOver(problem, tournament_size=4, cross_over_rate=1.0, eta=8),
                GaussianMutation(problem, stdev=0.03),
            ],
        )

        # perf_counter is monotonic and high-resolution. The clock stops right
        # after run(), so printing the status is not part of the measurement.
        start = time.perf_counter()
        searcher.run(100)
        elapsed = time.perf_counter() - start

        print("Final status:\n", searcher.status)
        print("Impl: ", clazz)
        print("Time spent (secs): ", elapsed)
        print("Filled hypervolumes: ", searcher.filled.sum())

Output:

[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The `dtype` for the problem's decision variables is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- `eval_dtype` (the dtype of the fitnesses and evaluation data) is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The `device` of the problem is set as cpu
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:6097491664) -- The number of actors that will be allocated for parallelized evaluation is 0
Final status:
 <LazyStatusDict
    pop_best = <not yet computed>
    pop_best_eval = <not yet computed>
    mean_eval = <not yet computed>
    median_eval = <not yet computed>
    iter = 100
    best = <Solution values=tensor([-1.1392, -1.1283, -1.1402]), evals=tensor([-26.1042, -14.5122, -11.5920])>
    worst = <Solution values=tensor([ 4.7636, -4.5227, -4.6175]), evals=tensor([18.8853, -5.4335, 24.3187])>
    best_eval = -26.104228973388672
    worst_eval = 18.88526725769043
>
Impl:  <class '__main__.MAPElitesScatter'>
Time spent (secs):  0.2477729320526123
Filled hypervolumes:  ReadOnlyTensor(1562)
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The `dtype` for the problem's decision variables is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- `eval_dtype` (the dtype of the fitnesses and evaluation data) is set as torch.float32
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The `device` of the problem is set as cpu
[2023-10-19 15:23:41] INFO     <46962> evotorch.core: Instance of `Problem` (id:5351110928) -- The number of actors that will be allocated for parallelized evaluation is 0
Final status:
 <LazyStatusDict
    pop_best = <not yet computed>
    pop_best_eval = <not yet computed>
    mean_eval = <not yet computed>
    median_eval = <not yet computed>
    iter = 100
    best = <Solution values=tensor([-1.1402, -1.1233, -1.1448]), evals=tensor([-26.1034, -14.5165, -11.5869])>
    worst = <Solution values=tensor([ 4.4850, -4.8089, -4.8177]), evals=tensor([18.5221, -5.2474, 23.7694])>
    best_eval = -26.103425979614258
    worst_eval = 18.5220890045166
>
Impl:  <class 'evotorch.algorithms.mapelites.MAPElites'>
Time spent (secs):  8.833541870117188
Filled hypervolumes:  ReadOnlyTensor(1522)