Open yklcs opened 6 months ago
I ended up creating my own implementation based on gsplat here: https://github.com/yklcs/jaxsplat
@yklcs this is cool work, are you planning to implement a full training pipeline in JAX?
This is something I'm very curious about, especially because splatfacto
currently relies a lot on things like dynamic shapes and boolean masking (which are hard in JAX).
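To make the boolean-masking difficulty concrete, here is a minimal sketch (not code from splatfacto or gsplat; the function names are hypothetical). Indexing with a boolean mask yields an array whose length depends on the data, which `jax.jit` cannot trace. One common workaround is `jnp.nonzero` with a static `size`, which pads the result and returns a count alongside it.

```python
import jax
import jax.numpy as jnp

def keep_visible_eager(points, visible):
    # Boolean masking: the output length depends on the values in `visible`,
    # so this works eagerly but raises an error when wrapped in jax.jit.
    return points[visible]

@jax.jit
def keep_visible_static(points, visible):
    # Static-shape alternative: always return n rows, padded with zeros,
    # plus the number of rows that are actually valid.
    n = points.shape[0]  # Python int, known at trace time
    (idx,) = jnp.nonzero(visible, size=n, fill_value=0)
    count = jnp.sum(visible)
    is_real = (jnp.arange(n) < count)[:, None]
    kept = jnp.where(is_real, points[idx], 0.0)
    return kept, count

points = jnp.arange(8.0).reshape(4, 2)
visible = jnp.array([True, False, True, False])
kept, count = keep_visible_static(points, visible)
```

The cost of the padded version is that every downstream consumer has to respect `count` (or an equivalent validity mask) instead of relying on the array length.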
Yes, dynamic shapes are a problem: as of right now JIT doesn't work. `gaussian_ids_sorted` is `num_intersects` long, which is dynamic depending on `num_tiles_hit`.
So a full pipeline would need to come after fixing that unless no JIT is acceptable.
I'm not sure what the best way of removing the dynamic shape is.
`num_intersects` is bounded by `num_tiles * num_points`, which is probably too big to store.
The tiling and binning approach may just be incompatible with statically known shapes.
Maybe someone else has better ideas?
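For a rough sense of scale, here is a back-of-the-envelope calculation of the `num_tiles * num_points` bound. All the numbers are assumptions chosen for illustration (a 1920x1080 image, 16x16 tiles, one million Gaussians, one 4-byte id per intersection); none of them come from jaxsplat or gsplat.

```python
import math

def worst_case_intersects(width, height, tile_size, num_points):
    # Number of tiles covering the image, rounding partial tiles up.
    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)
    num_tiles = tiles_x * tiles_y
    # Worst case: every Gaussian touches every tile.
    return num_tiles, num_tiles * num_points

# Assumed, illustrative sizes.
num_tiles, bound = worst_case_intersects(1920, 1080, 16, 1_000_000)
bytes_needed = bound * 4  # assuming a single int32 id per intersection

print(num_tiles)             # 8160
print(bound)                 # 8160000000
print(bytes_needed / 2**30)  # about 30 GiB
```

Under these assumptions the worst-case buffer for the ids alone is around 30 GiB, which is why allocating for the bound directly looks impractical.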
Yeah, it's an interesting problem!
It seems hard to make this useful without JIT. For making the shape static, could a `MAX_INTERSECTS` or `MAX_AVG_INTERSECTS_PER_GAUSSIAN` constant be good enough? If the number of intersects exceeds the constant:
- Maybe Gaussians can be prioritized based on distance or alpha, and any "overflow" can just be ignored?
- It seems possible to reduce memory usage by trading for computation; perhaps the forward/backward can be done in multiple passes? After the Gaussians are sorted it seems possible to chunk them by distance, rasterize separately, and then alpha-composite?
This could also be a feature and not a bug. Having `num_intersects = num_tiles * num_points` doesn't seem unreasonable (for example, if we have only large Gaussians), and choosing some well-defined behavior seems better than a spurious OOM.
JIT now works with jaxsplat: I took the simple approach of recalculating `gaussian_ids` in the backward pass.
I'll see if there's a better approach later on; those ideas seem worth exploring.
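The fixed-budget idea (a `MAX_INTERSECTS` constant, nearer Gaussians prioritized, overflow ignored) could look roughly like the sketch below. This is not how jaxsplat or gsplat does it; `assign_slots`, `num_dropped`, and the tiny budget of 8 are hypothetical and chosen only to keep the example readable. It assumes a per-Gaussian `num_tiles_hit` count and a per-Gaussian depth, and fills a fixed number of intersection slots in depth order.

```python
import jax
import jax.numpy as jnp

MAX_INTERSECTS = 8  # hypothetical static budget, far too small for real use

@jax.jit
def assign_slots(num_tiles_hit, depths):
    # Give nearer Gaussians first claim on the fixed slot budget.
    order = jnp.argsort(depths)
    hits = num_tiles_hit[order]
    # ends[i] is one past the last slot owned by the i-th nearest Gaussian.
    ends = jnp.cumsum(hits)
    slots = jnp.arange(MAX_INTERSECTS)
    # For each slot, find which (depth-sorted) Gaussian owns it.
    owner = jnp.searchsorted(ends, slots, side="right")
    owner = jnp.minimum(owner, hits.shape[0] - 1)  # clamp padding slots
    # Slots past the true intersection count are padding.
    valid = slots < ends[-1]
    gaussian_ids = jnp.where(valid, order[owner], -1)
    # Intersections that did not fit in the budget are silently dropped.
    num_dropped = jnp.maximum(ends[-1] - MAX_INTERSECTS, 0)
    return gaussian_ids, valid, num_dropped

num_tiles_hit = jnp.array([3, 2, 4, 1])
depths = jnp.array([2.0, 0.5, 3.0, 1.0])
gaussian_ids, valid, num_dropped = assign_slots(num_tiles_hit, depths)
# gaussian_ids -> [1, 1, 3, 0, 0, 0, 2, 2]; num_dropped -> 2
```

Every output has a shape fixed by `MAX_INTERSECTS`, so the function traces under `jax.jit`. The trade-off is the one raised above: when the budget is exceeded, the farthest Gaussians lose some or all of their tiles (here Gaussian 2 keeps only two of its four), which is well-defined but lossy.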
I'm working on a preliminary port of gsplat to JAX. It looks feasible if I reuse the CUDA kernels (mostly) as-is and heavily modify the bindings, but it would also require substantial changes to the Python-side code and overall API. JAX's custom GPU FFI requires quite a bit of boilerplate.
I was wondering if there's any interest in merging JAX support into gsplat if I were to create a PR, or, even better, if there's a maintainer interested in collaborating to support JAX.
If not I'll just create a hard fork. Thanks.