google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

reintroduce the Threefry GPU kernel lowering, under a flag #21023

Closed: copybara-service[bot] closed this 2 weeks ago

copybara-service[bot] commented 2 weeks ago


On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This can reduce runtime memory usage, at the cost of longer compile times. The prior behavior, which lowers Threefry to a kernel call, can be restored with:

import jax
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
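A minimal sketch of opting back in and then drawing Threefry-based random numbers. The flag name comes from this PR. The helper `sample` is hypothetical and exists only for illustration. Setting the flag at program startup, before any random numbers are drawn, is assumed to be the safest placement. The flag only changes how Threefry is lowered on GPU, so on a CPU-only install the code still runs but the lowering is unaffected. Either lowering implements the same Threefry algorithm, so a fixed key is expected to produce the same values regardless of the setting.

```python
import jax
import jax.numpy as jnp

# Opt back in to the previous GPU lowering, which emits a Threefry kernel call.
# Set at startup, before any random numbers are generated.
jax.config.update('jax_threefry_gpu_kernel_lowering', True)

def sample(seed: int, shape=(4, 3)):
    """Hypothetical helper: draw standard-normal samples from a seeded key."""
    key = jax.random.PRNGKey(seed)  # default PRNG implementation is Threefry
    return jax.random.normal(key, shape)

x = sample(0)
y = sample(0)

# The same key always yields the same samples.
print(x.shape, x.dtype, bool(jnp.array_equal(x, y)))
```

To check which lowering is actually in effect on a GPU machine, compare compile times and peak memory for a large `jax.random.normal` call with the flag set to `True` and to `False`.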