LucasAlegre / morl-baselines

Implementations of Multi-Objective Reinforcement Learning algorithms.
https://lucasalegre.github.io/morl-baselines
MIT License

Implement fast pareto and convex hull pruning #60

Closed wilrop closed 1 year ago

wilrop commented 1 year ago

In this pull request, I add batched, fast and, importantly, correct versions of Pareto prune and convex hull prune. Notably, many of the implementations found online are wrong. This implementation has been extensively tested, so it should not have any further errors.

Concretely,

  1. I added a fast_c_prune algorithm that leverages scipy's convex hull implementation (which is very fast); see the sketch further below.
  2. I added an arg_p_prune function that returns the indices of the elements of the input array that should be kept.
  3. I added a fast_p_prune function that simply calls the arg version and then selects the corresponding elements.

The way these functions work is very straightforward. Instead of doing lots of loops, which are nice from a theoretical perspective, I simply compare everything against everything in a batched way. This may result in more computations than strictly necessary, but it is extremely fast in practice.
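To give a rough idea of what this looks like, here is a simplified sketch (hypothetical names, not the exact code in the PR, and without the duplicate handling mentioned below); the whole domination check is a single batched (n, n, d) comparison:

```python
import numpy as np


def arg_p_prune_sketch(candidates: np.ndarray) -> np.ndarray:
    """Indices of the Pareto-efficient rows of `candidates` (maximisation assumed)."""
    # weakly[i, j]: candidate i is at least as good as candidate j in every objective.
    weakly = np.all(candidates[:, None, :] >= candidates[None, :, :], axis=-1)
    # strictly[i, j]: candidate i is strictly better than candidate j in some objective.
    strictly = np.any(candidates[:, None, :] > candidates[None, :, :], axis=-1)
    # A candidate is dominated if some other candidate is weakly better everywhere
    # and strictly better somewhere.
    dominated = np.any(weakly & strictly, axis=0)
    return np.flatnonzero(~dominated)


def fast_p_prune_sketch(candidates: np.ndarray) -> np.ndarray:
    """The Pareto-efficient rows themselves."""
    return candidates[arg_p_prune_sketch(candidates)]
```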

Note that the functions have an optional parameter that controls whether or not duplicates are removed by default.
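For the convex hull part (item 1 above), the idea is to let scipy do the heavy lifting: compute the hull, keep its vertices, and Pareto-filter them. Again a simplified sketch (hypothetical names, reusing the Pareto filter sketched above, not the exact implementation); it assumes enough affinely independent points for scipy.spatial.ConvexHull to succeed:

```python
import numpy as np
from scipy.spatial import ConvexHull


def fast_c_prune_sketch(candidates: np.ndarray) -> np.ndarray:
    """Approximate CCS pruning: keep the convex hull vertices that are
    also Pareto-efficient (maximisation assumed)."""
    hull = ConvexHull(candidates)          # needs more than d affinely independent points
    vertices = candidates[hull.vertices]   # includes vertices on the "lower" side of the hull
    return fast_p_prune_sketch(vertices)   # drop the Pareto-dominated lower vertices
```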

As a final suggestion, it is quite easy to make a JAX-jitted version of fast_p_prune, but this adds another dependency and may or may not be worth the extra overhead for you in the end. I'll leave it as an exercise to the reader 😉 (or you can let me know if you want that)

ffelten commented 1 year ago

Hello, thanks for this PR, which will have a huge impact on performance (especially the CCS pruning, I think).

Could you consider:

  1. Using CCS pruning where it fits, e.g. in most algorithms, since we generally do weighted-sum scalarization. For example, PGMORL relies on the ParetoArchive; we could imagine an option for convexity in that class that relies on the CCS pruning. I also think it would be a nice addition to Pareto prune before log_all_multi_policy_metrics is called, as computing the metrics can take a lot of time (hypervolume is NP-hard).
  2. Moving the implementations related to Pareto fronts and pruning next to each other (in common/pareto.py, or rename the file to something like common/pruning.py).

Obviously happy to help if needed.

wilrop commented 1 year ago

So I think your second point was done by Lucas, but the first one is still open. However, I'm not completely sure what you would like me to do there. One thing I can do in the meantime is add my tests that show it actually works, and also make the functions accept a list of numpy arrays as input rather than just an array (suggested by Lucas).

wilrop commented 1 year ago

Oh, I think I see now. You want me to add a convex parameter to the ParetoArchive that calls the CCS pruning algorithm instead of the PCS pruning one. I can do that! I'll also change the call to filter_pareto_dominated rather than get_non_dominated.
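Roughly, what I have in mind is something like this (a toy sketch reusing the hypothetical pruning functions from my first comment, not the actual ParetoArchive code):

```python
import numpy as np


class ParetoArchiveSketch:
    """Toy archive that switches between PCS and CCS pruning via a convex flag."""

    def __init__(self, convex: bool = False):
        # If convex is True, prune to the convex coverage set (CCS);
        # otherwise keep the Pareto coverage set (PCS).
        self.convex = convex
        self.points = np.empty((0, 0))

    def add(self, new_points: np.ndarray) -> None:
        if self.points.size == 0:
            candidates = new_points
        else:
            candidates = np.vstack([self.points, new_points])
        prune = fast_c_prune_sketch if self.convex else fast_p_prune_sketch
        self.points = prune(candidates)
```

An algorithm like PGMORL could then construct the archive with convex=True when it only ever does weighted-sum scalarization.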

ffelten commented 1 year ago

FYI, before we merge this, I would like to tag v1.0 on the main branch with the v1.0 version of MO-Gymnasium. So please do not merge yet @LucasAlegre

wilrop commented 1 year ago

Alright everyone, the above commit implements a bunch of tests, adds lists as accepted inputs for the functions, and extends the Pareto archive with an optional convex hull parameter. I have a couple of important notes below.

Benchmarking

I've done some additional benchmarking, and it appears that my original get_non_dominated function is still significantly faster than my new addition for the Pareto pruning. There is a reason and a semi-solution for this. The reason is that get_non_dominated is a good mixture between the PPrune algorithm and leveraging batch operations. My new algorithm does none of that and just compares everything against everything simultaneously. The benefit of doing so is that you can easily JIT compile it using, e.g., JAX, but if you do not do that, the previous algorithm is still faster. I'm not sure you want to add an additional JAX dependency to the project, so I'll leave that up to you.
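For reference, a jitted variant of the all-vs-all comparison could look roughly like this (purely illustrative, not part of this PR, and it would add JAX as a dependency):

```python
import jax.numpy as jnp
from jax import jit


@jit
def pareto_mask(points):
    # points: (n, d) array of objective vectors, maximisation assumed.
    weakly = jnp.all(points[:, None, :] >= points[None, :, :], axis=-1)
    strictly = jnp.any(points[:, None, :] > points[None, :, :], axis=-1)
    dominated = jnp.any(weakly & strictly, axis=0)
    # Return a boolean mask of non-dominated points; selecting the surviving
    # rows happens outside the jitted function because that output shape
    # depends on the data.
    return ~dominated
```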

Pareto Archive

The code in pareto.py contains lines 139-147 whose purpose I cannot really figure out. If someone can check whether they are still necessary now, you might be able to get rid of them. My intuition is that the get_non_dominated method may mess up the ordering because it moves everything into sets, but if you now call a function that operates on ndarrays you won't have that problem.

Tests

My tests are actually quite cool and might be useful again in the future, so I'll briefly explain them. For the PF, I just sample a bunch of points uniformly from the surface of the unit d-dimensional ball using the Box-Muller transform (see: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform). I then create a bunch of dominated points by selecting points on the PF at random and subtracting a random amount from them. For the CH, I generate non-dominated points the same way, but then construct the convex hull using scipy. I then take the simplicial facets of the convex hull and generate points on these simplices, which creates points that are not Pareto dominated but are convex dominated!
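In code, the point generation looks roughly like this (a simplified sketch of the idea with made-up names, not the exact test code; in particular it glosses over selecting only the "upper" facets of the hull, which the real tests need to be careful about):

```python
import numpy as np
from scipy.spatial import ConvexHull

rng = np.random.default_rng(0)


def sample_front(n: int, d: int) -> np.ndarray:
    """Points uniformly on the positive part of the unit sphere: normalised
    absolute Gaussian samples, so no point Pareto-dominates another."""
    x = np.abs(rng.normal(size=(n, d)))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def add_pareto_dominated(front: np.ndarray, n: int) -> np.ndarray:
    """Dominated points: random front points minus a random positive offset."""
    picks = front[rng.integers(0, len(front), size=n)]
    return picks - rng.uniform(0.05, 0.5, size=picks.shape)


def add_convex_dominated(front: np.ndarray, per_facet: int) -> np.ndarray:
    """Points inside the simplicial facets of the hull: convex combinations of
    facet vertices, intended to be convex-dominated but not Pareto-dominated."""
    hull = ConvexHull(front)
    samples = []
    for simplex in hull.simplices:
        weights = rng.dirichlet(np.ones(len(simplex)), size=per_facet)
        samples.append(weights @ front[simplex])
    return np.vstack(samples)
```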

Voila, I hope this will be useful for you all!

LucasAlegre commented 1 year ago

Thanks a lot again @wilrop ! :D

ffelten commented 1 year ago

@wilrop 👑

wilrop commented 1 year ago

It was my pleasure! I'll add some more stuff in the future ;)