icaros-usc / pyribs

A bare-bones Python library for quality diversity optimization.
https://pyribs.org
MIT License

Speed up 2D cvt_archive_heatmap by order of magnitude #355

Closed: btjanaka closed this 10 months ago.

btjanaka commented 10 months ago

Description

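Currently, cvt_archive_heatmap plots each polygon individually via ax.fill. Each ax.fill call creates and adds a separate Polygon artist, so a 10,000-cell archive produces 10,000 artists. We can speed this up by building a single PolyCollection that holds all the polygons and adding it to the axes at once. This is similar to using a PatchCollection, as shown here: https://matplotlib.org/stable/gallery/shapes_and_collections/patch_collection.html.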
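A minimal sketch contrasting the two approaches, assuming each cell is already available as an (N, 2) array of vertices and is colored by one scalar value (e.g., its objective). The helpers grid_squares, plot_cells_with_fill, and plot_cells_with_collection are hypothetical illustrations, not part of the pyribs API and not the code in this PR.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection
from matplotlib.colors import Normalize


def grid_squares(n):
    """Hypothetical stand-in for CVT cells: an n x n grid of unit squares."""
    return [
        np.array([[i, j], [i + 1, j], [i + 1, j + 1], [i, j + 1]], dtype=float)
        for i in range(n)
        for j in range(n)
    ]


def plot_cells_with_fill(ax, polygons, values, cmap="magma"):
    """Slow path: one ax.fill call, and thus one Polygon artist, per cell."""
    colormap = matplotlib.colormaps[cmap]
    norm = Normalize(vmin=np.min(values), vmax=np.max(values))
    for vertices, value in zip(polygons, values):
        # Each cell's color is computed by hand from the colormap.
        ax.fill(vertices[:, 0], vertices[:, 1], color=colormap(norm(value)))


def plot_cells_with_collection(ax, polygons, values, cmap="magma"):
    """Fast path: a single PolyCollection artist holding every cell."""
    collection = PolyCollection(polygons, cmap=cmap, edgecolors="none")
    # The collection maps each value to a color through its colormap and norm.
    collection.set_array(np.asarray(values))
    ax.add_collection(collection)
    ax.autoscale_view()  # refresh the view limits to include the new cells
    return collection


# Usage: one artist for all 100 cells, which can also drive a colorbar.
fig, ax = plt.subplots()
cells = grid_squares(10)
values = np.random.default_rng(0).standard_normal(len(cells))
collection = plot_cells_with_collection(ax, cells, values)
fig.colorbar(collection, ax=ax)
```

Because the collection is itself a ScalarMappable, it handles the value-to-color mapping and can be passed directly to fig.colorbar, instead of each cell's color being computed separately.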

Benchmark for plotting a CVTArchive with 10,000 cells:

I used the following code to benchmark the implementation:

```python
"""Driver for cvt heatmap experiments."""

import time

import fire
import matplotlib.pyplot as plt
import numpy as np

from ribs.archives import CVTArchive
from ribs.visualize import cvt_archive_heatmap


def main(n_cells=10000):
    """Creates the archive and plots it."""
    np.random.seed(42)

    # 2D archive whose n_cells centroids are sampled uniformly in [-1, 1]^2
    # instead of being computed with k-means.
    archive = CVTArchive(
        solution_dim=3,
        cells=n_cells,
        ranges=[(-1, 1), (-1, 1)],
        custom_centroids=np.random.uniform(-1, 1, (n_cells, 2)),
    )

    # Insert 20,000 random entries (solutions, objectives, measures).
    archive.add(
        np.random.uniform(-1, 1, (20000, 3)),
        np.random.standard_normal(20000),
        np.random.uniform(-1, 1, (20000, 2)),
    )

    plt.figure(figsize=(8, 6))

    # Only the heatmap call is timed; saving the figure happens afterwards.
    start_time = time.time()
    cvt_archive_heatmap(archive)
    print("Plot time", time.time() - start_time)

    plt.savefig("cvt.png")


if __name__ == "__main__":
    fire.Fire(main)
```
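Matplotlib renders lazily, so timing only the plotting call leaves out the drawing work that happens later (here, during savefig). The following pyribs-independent sketch times both approaches with a forced fig.canvas.draw() inside the timed region. It uses hypothetical random triangles as stand-ins for the CVT cells; random_triangles, time_fill, and time_collection are illustrative names, not pyribs functions.

```python
import time

import matplotlib

matplotlib.use("Agg")  # headless backend so drawing works without a display

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection


def random_triangles(n_cells, seed=42):
    """Hypothetical stand-in for CVT cells: small triangles in [-1, 1]^2."""
    rng = np.random.default_rng(seed)
    centers = rng.uniform(-1, 1, (n_cells, 2))
    offsets = rng.uniform(-0.02, 0.02, (n_cells, 3, 2))
    return centers[:, None, :] + offsets  # shape (n_cells, 3, 2)


def time_fill(polygons):
    """Seconds to add one ax.fill polygon per cell and render the figure."""
    fig, ax = plt.subplots(figsize=(8, 6))
    start = time.perf_counter()
    for vertices in polygons:
        ax.fill(vertices[:, 0], vertices[:, 1])
    fig.canvas.draw()  # include rendering, not just artist creation
    elapsed = time.perf_counter() - start
    plt.close(fig)
    return elapsed


def time_collection(polygons):
    """Seconds to add a single PolyCollection and render the figure."""
    fig, ax = plt.subplots(figsize=(8, 6))
    start = time.perf_counter()
    ax.add_collection(PolyCollection(polygons))
    ax.autoscale_view()
    fig.canvas.draw()  # include rendering, not just artist creation
    elapsed = time.perf_counter() - start
    plt.close(fig)
    return elapsed
```

Calling time_fill(random_triangles(10000)) and time_collection(random_triangles(10000)) mirrors the 10,000-cell setting of the script above; actual timings depend on the machine, backend, and Matplotlib version.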

TODO

Questions

Status