nerfstudio-project / gsplat

CUDA accelerated rasterization of gaussian splatting
https://docs.gsplat.studio/
Apache License 2.0

JAX support? #175

Open yklcs opened 6 months ago

yklcs commented 6 months ago

I'm working on a preliminary port of gsplat to JAX. It appears that it would be possible if I reuse the CUDA kernels (mostly) as-is and heavily modify the bindings. However, it would also require substantial changes to the Python-side code and the overall API, and JAX's custom GPU FFI requires quite a bit of boilerplate.

I was wondering whether there's any interest in merging JAX support into gsplat if I were to open a PR, or, even better, whether a maintainer would be interested in collaborating on JAX support.

If not, I'll just create a hard fork. Thanks.

yklcs commented 6 months ago

I ended up creating my own implementation based on gsplat here: https://github.com/yklcs/jaxsplat

brentyi commented 6 months ago

@yklcs This is cool work! Are you planning to implement a full training pipeline in JAX?

This is something I'm very curious about, especially because splatfacto currently relies heavily on things like dynamic shapes and boolean masking, which are hard to do in JAX.
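To make the boolean-masking problem concrete, here is a small sketch (the function names are hypothetical and not taken from splatfacto): indexing with a boolean mask produces an output whose length depends on the data, so it only works eagerly, while `jnp.where` and `jnp.nonzero(..., size=..., fill_value=...)` keep the output shape static and can be jitted.

```python
import jax
import jax.numpy as jnp


def visible_opacities_masked(opacities, depths):
    # Boolean masking: the output length depends on how many depths are positive,
    # so this works eagerly but cannot be traced by jax.jit.
    return opacities[depths > 0.0]


def visible_opacities_where(opacities, depths):
    # Static-shape alternative: keep every slot and zero out culled Gaussians.
    return jnp.where(depths > 0.0, opacities, 0.0)


def visible_indices_padded(depths, max_visible):
    # Static-shape alternative: jnp.nonzero with a fixed `size`, padded with -1.
    (idx,) = jnp.nonzero(depths > 0.0, size=max_visible, fill_value=-1)
    return idx


where_jit = jax.jit(visible_opacities_where)
# max_visible determines the output shape, so it must be a static argument.
indices_jit = jax.jit(visible_indices_padded, static_argnames="max_visible")
```

For depths `[1.0, -1.0, 2.0, -0.5]`, `indices_jit(depths, max_visible=3)` returns `[0, 2, -1]`. The trade-off is that `max_visible` has to be chosen up front, which is the same kind of static bound that comes up for num_intersects.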

yklcs commented 6 months ago

Yes, dynamic shapes are a problem: as of right now, JIT doesn't work. gaussian_ids_sorted has length num_intersects, which is dynamic because it depends on num_tiles_hit. So a full pipeline would have to wait until that is fixed, unless running without JIT is acceptable.
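A minimal sketch of why this breaks JIT, assuming (as a simplification, not gsplat's actual code) that the length of gaussian_ids_sorted is the sum of num_tiles_hit, computed with a prefix sum. The buffer length is only known after the data has been seen, so the function runs eagerly but fails with a concretization error under `jax.jit`. The name `alloc_gaussian_ids_sorted` is hypothetical.

```python
import jax
import jax.numpy as jnp


def alloc_gaussian_ids_sorted(num_tiles_hit):
    # Inclusive prefix sum: entry i is where Gaussian i's intersections end.
    cum_tiles_hit = jnp.cumsum(num_tiles_hit)
    # Reading the total back as a Python int needs a concrete value, which a
    # traced array inside jax.jit does not have...
    num_intersects = int(cum_tiles_hit[-1])
    # ...but the buffer length must be known when the array is created.
    return jnp.zeros((num_intersects,), dtype=jnp.int32)
```

Called eagerly with `jnp.array([2, 0, 3])` it returns a length-5 buffer; wrapping it in `jax.jit` raises instead, because under tracing the sum is an abstract value.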

I'm not sure what the best way to remove the dynamic shape is. num_intersects is bounded by num_tiles * num_points, which is probably too large to store. The tiling-and-binning approach may simply be incompatible with statically known shapes. Maybe someone else has better ideas?
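For a sense of scale, here is a back-of-the-envelope bound under assumed numbers: a 1920×1080 image, 16×16-pixel tiles, one million Gaussians, and 12 bytes per intersection (an assumed 64-bit sort key plus a 32-bit Gaussian ID; the real buffers may differ).

```python
import math

TILE_SIZE = 16               # assumed tile edge length in pixels
BYTES_PER_INTERSECT = 8 + 4  # assumed: int64 sort key + int32 Gaussian ID


def worst_case_intersects(width, height, num_points, tile_size=TILE_SIZE):
    """Upper bound num_tiles * num_points, reached if every Gaussian covers every tile."""
    num_tiles = math.ceil(width / tile_size) * math.ceil(height / tile_size)
    return num_tiles, num_tiles * num_points


num_tiles, bound = worst_case_intersects(1920, 1080, 1_000_000)
print(f"{num_tiles} tiles, {bound:,} intersects, {bound * BYTES_PER_INTERSECT / 1e9:.2f} GB")
# prints: 8160 tiles, 8,160,000,000 intersects, 97.92 GB
```

Roughly 98 GB for these buffers alone supports the point that allocating for the worst case is not practical.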

brentyi commented 6 months ago

Yeah, it's an interesting problem!

It seems hard to make this useful without JIT. For making the shape static, could a MAX_INTERSECTS or MAX_AVG_INTERSECTS_PER_GAUSSIAN constant be good enough? If the number of intersects exceeds the constant:

  • Maybe Gaussians can be prioritized based on distance or alpha, and any "overflow" can just be ignored?
  • It seems possible to reduce memory usage by trading it for computation: perhaps the forward and backward passes can each be done in multiple passes? After the Gaussians are sorted, it seems possible to chunk them by distance, rasterize each chunk separately, and then alpha-composite the results?

This could also be a feature and not a bug. Having num_intersects = num_tiles * num_points doesn't seem unreasonable (for example, if all the Gaussians are large), and choosing some well-defined behavior seems better than a spurious OOM.

yklcs commented 5 months ago

JIT now works with jaxsplat: I took the simple approach of recalculating gaussian_ids in the backward pass. I'll see if there's a better approach later on; those ideas seem worth exploring.
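Both suggestions in brentyi's list can be sketched in pure JAX under stated assumptions. `bin_with_cap`, `composite`, and `composite_chunked` are hypothetical helpers, not part of gsplat or jaxsplat, and `MAX_INTERSECTS = 8` is a toy value. The first function sorts Gaussians globally by depth (a simplification; a real tile pipeline sorts by tile and then depth), expands them into a fixed-length ID buffer with `jnp.repeat(..., total_repeat_length=...)`, drops whatever overflows, and returns the true count so overflow stays detectable. The second performs front-to-back alpha compositing for a single pixel in fixed-size chunks with `jax.lax.scan`, carrying accumulated color and transmittance and merging each chunk with the "over" rule `C += T * C_chunk`, `T *= T_chunk`. That rule is what would let chunks be rasterized separately.

```python
import jax
import jax.numpy as jnp

MAX_INTERSECTS = 8  # hypothetical static cap; a real value would be far larger


@jax.jit
def bin_with_cap(depths, num_tiles_hit):
    """Fixed-length Gaussian-ID buffer; overflow from the farthest Gaussians is dropped."""
    order = jnp.argsort(depths)  # nearest Gaussian first
    counts = num_tiles_hit[order]
    # Slot k holds Gaussian order[rank[k]]. Entries past MAX_INTERSECTS are discarded;
    # if sum(counts) is smaller, jnp.repeat pads, so those slots are masked below.
    rank = jnp.repeat(jnp.arange(depths.shape[0]), counts,
                      total_repeat_length=MAX_INTERSECTS)
    num_intersects = jnp.sum(counts)
    valid = jnp.arange(MAX_INTERSECTS) < num_intersects
    gaussian_ids = jnp.where(valid, order[rank], -1)
    # num_intersects > MAX_INTERSECTS signals that something was dropped.
    return gaussian_ids, num_intersects


def composite(colors, alphas):
    """Front-to-back: C = sum_i c_i a_i prod_{j<i} (1 - a_j); also returns final T."""
    trans = jnp.cumprod(1.0 - alphas)
    trans_before = jnp.concatenate([jnp.ones(1), trans[:-1]])
    color = jnp.sum((alphas * trans_before)[:, None] * colors, axis=0)
    return color, trans[-1]


def composite_chunked(colors, alphas, chunk_size):
    """Same result as composite(), processed chunk_size depth-sorted Gaussians at a time."""
    n = alphas.shape[0]
    if n % chunk_size:
        raise ValueError("pad with alpha=0 Gaussians to a multiple of chunk_size")
    colors_c = colors.reshape(n // chunk_size, chunk_size, colors.shape[-1])
    alphas_c = alphas.reshape(n // chunk_size, chunk_size)

    def step(carry, chunk):
        acc_color, acc_trans = carry
        c_chunk, t_chunk = composite(*chunk)
        # "Over" operator: this chunk lies behind everything accumulated so far.
        return (acc_color + acc_trans * c_chunk, acc_trans * t_chunk), None

    init = (jnp.zeros(colors.shape[-1]), jnp.ones(()))
    (color, trans), _ = jax.lax.scan(step, init, (colors_c, alphas_c))
    return color, trans
```

As an example, `bin_with_cap(jnp.array([3.0, 1.0, 2.0]), jnp.array([4, 3, 2]))` needs 9 slots, so it returns IDs `[1, 1, 1, 2, 2, 0, 0, 0]` with a count of 9: the farthest Gaussian (index 0) loses one of its tiles. Because the "over" operator is associative, `composite_chunked` matches `composite` up to floating-point rounding for any chunk size that divides the number of Gaussians.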