sotetsuk / pgx

♟️ Vectorized RL game environments in JAX
http://sotets.uk/pgx/
Apache License 2.0

[Chess] Accelerate `chess_utils.py` import (jnp => np) #1170

Closed. Akulen closed this 4 months ago.

Akulen commented 5 months ago

Importing `_src/chess_utils.py` currently takes 10 to 15 seconds, mainly because all precomputations are performed on `jax.numpy` arrays, which are very slow for this kind of work. This PR reduces the import time to about 1 s by doing the precomputation with NumPy arrays and converting the results to `jax.numpy` arrays at the end.
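The pattern described above can be sketched as follows. This is a minimal illustration, not pgx's actual code: the table (knight-move destinations per square) and all function names are hypothetical. The NumPy version fills a mutable array in place. The `jax.numpy` version has to use `.at[...].set(...)`, which returns a new immutable array and dispatches a separate operation on every call; that per-element overhead is one plausible source of the slow import. Only the finished table is converted with `jnp.asarray`.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical example table: for each of the 64 squares (index = 8 * rank + file),
# the destination square of each of the 8 knight moves, or -1 if it leaves the board.
# Illustrates the PR's approach only; this is not pgx's chess_utils.py code.
KNIGHT_DELTAS = [(1, 2), (2, 1), (2, -1), (1, -2), (-1, -2), (-2, -1), (-2, 1), (-1, 2)]


def knight_table_numpy() -> np.ndarray:
    # Mutable NumPy array: each write is a cheap in-place store on the host.
    table = -np.ones((64, 8), dtype=np.int32)
    for sq in range(64):
        r, f = divmod(sq, 8)
        for i, (dr, df) in enumerate(KNIGHT_DELTAS):
            nr, nf = r + dr, f + df
            if 0 <= nr < 8 and 0 <= nf < 8:
                table[sq, i] = 8 * nr + nf
    return table


def knight_table_jnp() -> jnp.ndarray:
    # Same loop on an immutable jax.numpy array: every .at[].set() call is a
    # separately dispatched operation that produces a new array.
    table = -jnp.ones((64, 8), dtype=jnp.int32)
    for sq in range(64):
        r, f = divmod(sq, 8)
        for i, (dr, df) in enumerate(KNIGHT_DELTAS):
            nr, nf = r + dr, f + df
            if 0 <= nr < 8 and 0 <= nf < 8:
                table = table.at[sq, i].set(8 * nr + nf)
    return table


# Approach taken in the PR: precompute with NumPy, convert to jax.numpy once at the end.
KNIGHT_TABLE = jnp.asarray(knight_table_numpy())
```

Both functions produce the same values. Since the table is built once at import time, the single `jnp.asarray` conversion at the end costs far less than hundreds of individual JAX operations.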

sotetsuk commented 5 months ago

Thank you for the PR! I'll review it in a few days.

Akulen commented 5 months ago

I made the same change to gardner_chess. The gain is less impressive, but its import time still drops from 3-4 s to under 1 s.
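One way to reproduce before/after numbers like these is to time the import in a fresh interpreter, so that nothing is already cached in `sys.modules`. A minimal sketch using only the standard library; the helper name `import_time_seconds` is hypothetical:

```python
import subprocess
import sys


def import_time_seconds(module: str) -> float:
    """Time `import <module>` in a fresh Python process and return seconds."""
    # Run the import in a child interpreter so earlier imports in this
    # process cannot make the measurement look artificially fast.
    code = (
        "import time\n"
        "t0 = time.perf_counter()\n"
        f"import {module}\n"
        "print(time.perf_counter() - t0)\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        check=True,  # raise if the import fails in the child process
        capture_output=True,
        text=True,
    )
    return float(result.stdout.strip())
```

For example, `import_time_seconds("pgx._src.chess_utils")`, where the module path is an assumption based on `_src/chess_utils.py` above. The figure includes importing the parent packages and anything they pull in (such as JAX), so it overstates the cost of the module alone; `python -X importtime -c "import ..."` prints a per-module breakdown to stderr if that distinction matters.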

sotetsuk commented 4 months ago

Hi! I'm so sorry for the late response 🙏 I confirmed locally that it passes the tests and actually improves the speed! Many thanks! ❤️ CI failed in an unrelated part of the code; I'll fix that in a separate PR. I'll merge this PR. Again, thank you so much for your effort, and sorry for the late response 🙏