jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unimplemented primitive in Pallas GPU lowering: bitcast_convert_type #24582

Open duckworthd opened 1 day ago

duckworthd commented 1 day ago

I am implementing the rasterization step of 3D Gaussian Splatting in Pallas. A key part of this process is constructing a uint64 key for every Gaussian-tile intersection (a tile is a 16x16 patch of pixels). Given a tile with index i: uint32 and a Gaussian whose mean lies at distance d: float32 from the camera origin, the key is:

key = (jnp.astype(i, jnp.uint64) << 32) | d.view(jnp.uint32)
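The following is a minimal sketch of the key construction above in plain JAX, outside Pallas. It assumes 64-bit mode is enabled through jax_enable_x64. Without that setting, JAX falls back to uint32 with a warning, and the shift would discard the tile index. The helper names make_sort_keys and unpack_sort_keys are hypothetical. Because d is a non-negative distance, its float32 bit pattern, read as uint32, increases monotonically with d. As a result, sorting the keys groups entries by tile and orders them front to back within each tile.

```python
import jax

# uint64 requires 64-bit mode; without it JAX truncates to uint32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax


def make_sort_keys(tile_idx, depth):
    """Hypothetical helper: pack (tile index, depth) into one uint64 key.

    tile_idx: uint32 array of tile indices.
    depth: float32 array of non-negative camera distances.
    """
    # Reinterpret the float32 bits as uint32 (same as depth.view(jnp.uint32)).
    depth_bits = lax.bitcast_convert_type(depth, jnp.uint32)
    # High 32 bits hold the tile index, low 32 bits hold the depth bits.
    return (tile_idx.astype(jnp.uint64) << 32) | depth_bits.astype(jnp.uint64)


def unpack_sort_keys(keys):
    """Hypothetical helper: recover (tile index, depth) from packed keys."""
    tile_idx = (keys >> 32).astype(jnp.uint32)
    depth_bits = (keys & 0xFFFFFFFF).astype(jnp.uint32)
    depth = lax.bitcast_convert_type(depth_bits, jnp.float32)
    return tile_idx, depth


tile_idx = jnp.array([1, 0, 1, 0], dtype=jnp.uint32)
depth = jnp.array([2.5, 7.0, 0.5, 1.0], dtype=jnp.float32)
keys = make_sort_keys(tile_idx, depth)
# Tile 0 entries come first, and within each tile the nearest entry comes first.
order = jnp.argsort(keys)  # → [3, 1, 2, 0]
sorted_tiles, sorted_depth = unpack_sort_keys(keys[order])
```

The bit-pattern trick only preserves ordering for non-negative floats. Negative float32 values would sort in reverse, which is not a concern for camera distances.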

The keys are then sorted, and each tile is processed independently. The operation d.view(jnp.uint32), or equivalently jax.lax.bitcast_convert_type, is not yet supported by the Pallas GPU lowering.
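A minimal reproduction sketch follows. It assumes the bitcast happens inside the body of a Pallas kernel, and the names bitcast_kernel and bitcast_f32_to_u32 are hypothetical. With interpret=True, the kernel runs through the Pallas interpreter on any backend. According to this report, compiling with interpret=False on a GPU backend is where the missing lowering rule would surface, at least at the time of writing.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl


def bitcast_kernel(x_ref, o_ref):
    # Reinterpret float32 bits as uint32 inside the kernel body.
    # This is the primitive the report says the Pallas GPU lowering lacks.
    o_ref[...] = lax.bitcast_convert_type(x_ref[...], jnp.uint32)


def bitcast_f32_to_u32(x, interpret=True):
    # No BlockSpecs: the whole array is passed to the kernel as one block.
    return pl.pallas_call(
        bitcast_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.uint32),
        interpret=interpret,
    )(x)


x = jnp.linspace(0.0, 1.0, 16, dtype=jnp.float32)
y = bitcast_f32_to_u32(x)  # interpreter mode, so no GPU lowering is involved
# Reference result computed with plain JAX, outside Pallas.
expected = lax.bitcast_convert_type(x, jnp.uint32)
```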

superbobry commented 1 day ago

Thanks for the bug report @duckworthd! This should be straightforward to fix.