sotetsuk / pgx

♟️ Vectorized RL game environments in JAX
http://sotets.uk/pgx/
Apache License 2.0
394 stars 25 forks source link

[Mahjong] Reduce compilation/run time #1070

Open OkanoShinri opened 11 months ago

OkanoShinri commented 11 months ago
benchmark.py ```py from pgx._mahjong._mahjong2 import ( Mahjong, _discard, _selfkan, _riichi, _tsumo, _ron, _pon, _minkan, _pass, ) import jax import time import sys # func(state, action) functions1 = {"_discard": _discard, "_selfkan": _selfkan} # func(state) functions2 = { "_riichi": _riichi, "_tsumo": _tsumo, "_ron": _ron, "_pon": _pon, "_minkan": _minkan, "_pass": _pass, } env = Mahjong() func_name = sys.argv[1] if func_name in functions1: func = functions1[func_name] key = jax.random.PRNGKey(352) state = env.init(key=key) time_sta = time.perf_counter() jax.jit(func)(state, 0) time_end = time.perf_counter() delta = (time_end - time_sta) * 1000 exp = jax.make_jaxpr(func)(state, 0) n_line = len(str(exp).split("\n")) print(f"| `{func.__name__}` | {n_line} | {delta:.1f}ms |") elif func_name in functions2: func = functions2[func_name] key = jax.random.PRNGKey(352) state = env.init(key=key) time_sta = time.perf_counter() jax.jit(func)(state) time_end = time.perf_counter() delta = (time_end - time_sta) * 1000 exp = jax.make_jaxpr(func)(state) n_line = len(str(exp).split("\n")) print(f"| `{func.__name__}` | {n_line} | {delta:.1f}ms |") ```
benchmark.sh ```py echo "| function name | # expr lines | compile time |" echo "| :--- | ---: | ---: |" for funcname in _discard _selfkan _riichi _tsumo _ron _pon _minkan _pass do python3 benchmark.py $funcname done ```
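| function name | # expr lines | compile time |
| :--- | ---: | ---: |
| `_discard` | 12245 | 4594.9ms |
| `_selfkan` | 333 | 202.9ms |
| `_riichi` | 934 | 447.8ms |
| `_tsumo` | 2328 | 1114.8ms |
| `_ron` | 2280 | 1024.9ms |
| `_pon` | 261 | 112.6ms |
| `_minkan` | 324 | 122.2ms |
| `_pass` | 8539 | 2385.0ms |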
OkanoShinri commented 11 months ago
benchmark.py ```py from pgx.mahjong._env import ( Mahjong, _discard, _selfkan, _riichi, _tsumo, _ron, _pon, _minkan, _pass, ) import jax import time import sys import timeit # func(state, action) functions1 = {"_discard": _discard, "_selfkan": _selfkan} # func(state) functions2 = { "_riichi": _riichi, "_tsumo": _tsumo, "_ron": _ron, "_pon": _pon, "_minkan": _minkan, "_pass": _pass, } env = Mahjong() N = 10 func_name = sys.argv[1] if func_name in functions1: func = functions1[func_name] key = jax.random.PRNGKey(352) state = env.init(key=key) time_sta = time.perf_counter() jax.jit(func)(state, 0) time_end = time.perf_counter() delta = (time_end - time_sta) * 1000 exp = jax.make_jaxpr(func)(state, 0) n_line = len(str(exp).split("\n")) jit_func = jax.jit(func) run_delta = timeit.timeit( "jit_func(state, 0)", globals=globals(), number=N ) print( f"| `{func.__name__}` | {n_line} | {delta:.1f}ms | {run_delta/N*1000000:.1f}μs |" ) elif func_name in functions2: func = functions2[func_name] key = jax.random.PRNGKey(352) state = env.init(key=key) time_sta = time.perf_counter() jax.jit(func)(state) time_end = time.perf_counter() delta = (time_end - time_sta) * 1000 exp = jax.make_jaxpr(func)(state) n_line = len(str(exp).split("\n")) jit_func = jax.jit(func) run_delta = timeit.timeit("jit_func(state)", globals=globals(), number=N) print( f"| `{func.__name__}` | {n_line} | {delta:.1f}ms | {run_delta/N*1000000:.1f}μs |" ) ```
benchmark.sh ```py echo "| function name | # expr lines | compile time | running time |" echo "| :--- | ---: | ---: | ---: |" for funcname in _discard _selfkan _riichi _tsumo _ron _pon _minkan _pass do python3 benchmark.py $funcname done ```
| function name | # expr lines | compile time | running time |
| :--- | ---: | ---: | ---: |
| `_discard` | 12245 | 5012.2ms | 264.3μs |
| - `_draw` | 5517 | 2137.6ms | 244.9μs |
| -- `_make_legal_action_mask` | 3048 | 1281.3ms | 106.2μs |
| -- `_make_legal_action_mask_w_riichi` | 2394 | 1088.4ms | 78.7μs |
| `_selfkan` | 333 | 198.3ms | 64.7μs |
| `_riichi` | 934 | 427.5ms | 104.6μs |
| `_tsumo` | 2328 | 1095.3ms | 77.1μs |
| `_ron` | 2280 | 1048.6ms | 104.6μs |
| `_pon` | 261 | 117.2ms | 64.7μs |
| `_minkan` | 324 | 133.0ms | 69.7μs |
| `_pass` | 8539 | 2526.0ms | 223.7μs |
| `can_tsumo` | 225 | 115.3ms | 12.5μs |
| `can_ron` | 245 | 122.4ms | 10.6μs |
| `can_minkan` | 9 | 12.9ms | 8.6μs |
| `can_chi` | 206 | 63.3ms | 11.7μs |
| `can_riichi` | 424 | 289.8ms | 121.8μs |
| `is_tenpai` | 344 | 189.0ms | 11.9μs |
sotetsuk commented 11 months ago
shogi vs mahjong ``` {"game": "shogi", "library": "pgx/1dev", "total_steps": 200, "total_sec": 0.08473587036132812, "steps/sec": 2360.2755143384506, "batch_size": 2, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 200, "total_sec": 2.7215688228607178, "steps/sec": 73.48702642388972, "batch_size": 2, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 400, "total_sec": 0.08438754081726074, "steps/sec": 4740.036220112051, "batch_size": 4, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 400, "total_sec": 2.679060459136963, "steps/sec": 149.30607431265537, "batch_size": 4, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 800, "total_sec": 0.08239483833312988, "steps/sec": 9709.346072936349, "batch_size": 8, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 800, "total_sec": 2.8103582859039307, "steps/sec": 284.6612134874774, "batch_size": 8, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 1600, "total_sec": 0.10672569274902344, "steps/sec": 14991.70404789563, "batch_size": 16, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 1600, "total_sec": 2.893850564956665, "steps/sec": 552.896552218466, "batch_size": 16, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 3200, "total_sec": 0.1228034496307373, "steps/sec": 26057.899917487746, "batch_size": 32, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 3200, "total_sec": 2.9354007244110107, "steps/sec": 1090.1407679668953, "batch_size": 32, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 6400, "total_sec": 0.14135956764221191, "steps/sec": 45274.614988910536, "batch_size": 64, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 6400, 
"total_sec": 3.0983378887176514, "steps/sec": 2065.6236439883096, "batch_size": 64, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 12800, "total_sec": 0.19550633430480957, "steps/sec": 65471.024483758185, "batch_size": 128, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 12800, "total_sec": 3.35774564743042, "steps/sec": 3812.0814808576847, "batch_size": 128, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 25600, "total_sec": 0.35773587226867676, "steps/sec": 71561.1767912757, "batch_size": 256, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 25600, "total_sec": 4.034330606460571, "steps/sec": 6345.538454137644, "batch_size": 256, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 51200, "total_sec": 0.6315653324127197, "steps/sec": 81068.41425954246, "batch_size": 512, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 51200, "total_sec": 4.863647699356079, "steps/sec": 10527.078268185132, "batch_size": 512, "pgx.__version__": "2.0.0"} {"game": "shogi", "library": "pgx/1dev", "total_steps": 102400, "total_sec": 1.1985838413238525, "steps/sec": 85434.15693548627, "batch_size": 1024, "pgx.__version__": "2.0.0"} {"game": "mahjong", "library": "pgx/1dev", "total_steps": 102400, "total_sec": 7.006359338760376, "steps/sec": 14615.293770832695, "batch_size": 1024, "pgx.__version__": "2.0.0"} ```
image