I am in the process of implementing 3D Gaussian Splatting's rasterization with Pallas, and one key part of the process is constructing a uint64 key for every Gaussian-tile intersection (tile = 16x16 patch of pixels). Given a tile with index i: uint32 and a Gaussian whose mean is at distance d: float32 from the camera origin, the key places i in the upper 32 bits and the raw bits of d, reinterpreted as uint32, in the lower 32 bits.
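A minimal sketch of such a key in plain JAX (outside any Pallas kernel, where XLA lowers bitcast_convert_type normally). It assumes the layout used by the reference 3DGS CUDA rasterizer: tile index in the upper 32 bits, float32 depth bits in the lower 32 bits. make_keys is a hypothetical helper name, and 64-bit mode must be enabled or JAX will not produce uint64 arrays.

```python
import jax

# uint64 arrays require 64-bit mode; without it JAX downcasts to 32-bit types.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def make_keys(tile_idx, depth):
    """Pack (tile index, depth) into one uint64 sort key (hypothetical helper).

    Assumed layout: tile index in the upper 32 bits, raw float32 bits of the
    depth in the lower 32 bits.
    """
    tile_idx = tile_idx.astype(jnp.uint32)
    # Reinterpret the float32 bits as uint32: this is the op that Pallas cannot lower on GPU.
    depth_bits = jax.lax.bitcast_convert_type(depth.astype(jnp.float32), jnp.uint32)
    return (tile_idx.astype(jnp.uint64) << 32) | depth_bits.astype(jnp.uint64)


# Four intersections: two on tile 0 and two on tile 1.
tiles = jnp.array([1, 0, 1, 0], dtype=jnp.uint32)
depths = jnp.array([0.5, 2.0, 0.25, 1.0], dtype=jnp.float32)
keys = make_keys(tiles, depths)
order = jnp.argsort(keys)  # [3, 1, 2, 0]: tile 0 front to back, then tile 1
```

Because depth is positive for Gaussians in front of the camera, the IEEE-754 bit pattern of d grows monotonically with d, which is why sorting the raw bits orders Gaussians front to back within each tile. Computing depth_bits this way outside the kernel and passing a uint32 array into pallas_call may sidestep the missing lowering, though that is untested here.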
Keys will then be sorted, and each tile will be processed independently. However, the operation d.view(..) -- equivalently, jax.lax.bitcast_convert_type -- does not yet support lowering on GPU in Pallas.
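If the sort runs in regular JAX rather than inside the kernel, one possible workaround (a swapped-in technique, not the 3DGS packed-key approach) is to skip the uint64 key entirely: jax.lax.sort with num_keys=2 compares the first two operands lexicographically, so no bitcast or 64-bit dtype is needed. The helper names sort_intersections and tile_ranges are hypothetical; tile_ranges shows one way to find each tile's contiguous slice afterwards so tiles can be processed independently.

```python
import jax
import jax.numpy as jnp


def sort_intersections(tile_idx, depth, gaussian_id):
    """Order intersections by tile, then by depth, carrying gaussian_id as payload."""
    return jax.lax.sort(
        (tile_idx.astype(jnp.uint32), depth.astype(jnp.float32), gaussian_id),
        num_keys=2,  # first two operands are keys, compared lexicographically
    )


def tile_ranges(sorted_tiles, num_tiles):
    """Half-open [start, end) slice of the sorted arrays belonging to each tile."""
    tile_ids = jnp.arange(num_tiles, dtype=sorted_tiles.dtype)
    start = jnp.searchsorted(sorted_tiles, tile_ids, side="left")
    end = jnp.searchsorted(sorted_tiles, tile_ids, side="right")
    return start, end


tiles = jnp.array([1, 0, 1, 0], dtype=jnp.uint32)
depths = jnp.array([0.5, 2.0, 0.25, 1.0], dtype=jnp.float32)
ids = jnp.arange(4, dtype=jnp.int32)
sorted_tiles, sorted_depths, sorted_ids = sort_intersections(tiles, depths, ids)
start, end = tile_ranges(sorted_tiles, num_tiles=3)  # tile 2 has no intersections
```

The trade-off is that a multi-operand comparison sort may be slower than a radix sort over a single packed integer key, so this is mainly useful as a stopgap until the bitcast lowers in Pallas.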