RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
475 stars, 44 forks

Integer Overflow for Large Dimension Problems #40

Closed: twoletters closed 1 year ago

twoletters commented 1 year ago

num_dims is of type int32 and can overflow when squared:

https://github.com/RobertTLange/evosax/blob/3c7b751c434cf374516c261a8af367d492f5f859/evosax/strategies/cma_es.py#L65

Casting num_dims to a float before squaring fixes the issue:

        / (float(num_dims + 2) ** 2 + alpha_cov * mu_eff / 2),
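
To make the failure concrete, here is a minimal sketch in plain Python. It assumes the squared term behaves like a signed 32-bit integer, as described above, and emulates that wraparound with `ctypes.c_int32` so it runs without JAX. `wrap_int32` is a hypothetical helper for illustration, not part of evosax.

```python
import ctypes

INT32_MAX = 2**31 - 1


def wrap_int32(x: int) -> int:
    """Emulate signed 32-bit wraparound on a Python int (hypothetical helper)."""
    # ctypes integer types do no overflow checking; they truncate to 32 bits.
    return ctypes.c_int32(x).value


num_dims = 100_000  # a six-digit problem dimension

# Python ints have arbitrary precision, so this is the mathematically exact value.
exact = (num_dims + 2) ** 2

# What a signed 32-bit integer would hold after the square (assumed int32 behavior).
wrapped = wrap_int32(exact)

# The fix proposed in the issue: cast to float before squaring.
as_float = float(num_dims + 2) ** 2

print("exact:   ", exact)
print("int32:   ", wrapped, "(overflowed)" if exact > INT32_MAX else "")
print("as float:", as_float)
```

Under this assumption, any num_dims of 46339 or more overflows, because (46339 + 2) ** 2 is the first such square to exceed 2**31 - 1.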

RobertTLange commented 1 year ago

Hi @twoletters, sorry for the very late response, and thank you for raising this! You are right: for large network sizes/dimensions this will cause an overflow error. I have worked around this in Sep_CMA_ES by capping the value at a maximum of 40k in the development branch.

https://github.com/RobertTLange/evosax/blob/2f9ec7935108862e3c45a9413fffbf56c34d5341/evosax/strategies/sep_cma_es.py#L89

The value is only used to compute hyperparameters/learning rates in CMA-style algorithms, and from some experiments the cap does not seem to hurt. I will add the same clipping to the other CMA variants.

Let me know if you encounter any other issues. Best wishes and again thank you, Rob
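
The clipping workaround mentioned above can be sketched as follows. This is an illustration under assumptions, not the code in sep_cma_es.py: it assumes the cap is applied to the dimension count before it is squared, and `MAX_DIMS`, `clip_dims`, and `fits_int32` are hypothetical names.

```python
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1

# Cap of 40k taken from the thread; the constant name is hypothetical.
MAX_DIMS = 40_000


def clip_dims(num_dims: int, cap: int = MAX_DIMS) -> int:
    """Cap the dimension count used in hyperparameter formulas (assumed placement)."""
    return min(num_dims, cap)


def fits_int32(x: int) -> bool:
    """Check whether a value is representable as a signed 32-bit integer."""
    return INT32_MIN <= x <= INT32_MAX


for num_dims in (100, 40_000, 100_000, 1_000_000):
    clipped = clip_dims(num_dims)
    squared = (clipped + 2) ** 2
    # With the cap in place, the squared term always stays within int32 range.
    print(num_dims, clipped, squared, fits_int32(squared))
```

Unlike the float cast, clipping changes the resulting hyperparameter for every dimension above the cap, which is why the note that it did not seem to hurt in experiments matters.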

twoletters commented 1 year ago

Great! IIRC I found the error while using LM-MA-ES with a six-digit number of dimensions, so it could be happening in other places. Amazingly, algorithms like LM-MA-ES and CR-FM-NES can handle such large problems. It was great to be able to compare different approaches and see which ones worked best in my case. I am forever grateful for your work!

RobertTLange commented 1 year ago

Fixed for all relevant strategies in release v.0.1.4! Thank you for raising this.