dcharatan / pixelsplat

[CVPR 2024 Oral, Best Paper Runner-Up] Code for "pixelSplat: 3D Gaussian Splats from Image Pairs for Scalable Generalizable 3D Reconstruction" by David Charatan, Sizhe Lester Li, Andrea Tagliasacchi, and Vincent Sitzmann
http://davidcharatan.com/pixelsplat/
MIT License

support `fp16` training #78

Closed thucz closed 3 months ago

thucz commented 3 months ago

Hi! I want to train pixelSplat on higher-resolution images, but GPU memory limits what I can do. Do you know how to adapt your pipeline, especially the diff-gaussian-rasterization module, to FP16 training?

dcharatan commented 3 months ago

You should be able to switch the neural network components to FP16 using the steps here, since the code is based on PyTorch Lightning. diff-gaussian-rasterization would be harder to convert to FP16, because you would have to change the CUDA code in a number of places, and certain operations might not work at FP16 precision. I would start with the neural-network parts of pixelSplat, since those are responsible for the vast majority of the memory usage.
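
For reference, here is a minimal sketch of what enabling mixed precision looks like with a plain PyTorch Lightning `Trainer`. pixelSplat constructs its Trainer from its own config files, so in practice you would set the equivalent `precision` option there rather than building the Trainer by hand; the `model`/`datamodule` names below are placeholders.

```python
# Hedged sketch: FP16 mixed-precision training with a Lightning Trainer.
# The network runs forward/backward under FP16 autocast while master
# weights stay in FP32.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",  # older Lightning versions spell this precision=16
)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule are placeholders
```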

thucz commented 3 months ago

In fact, I tried training with precision=16, but the diff-gaussian-rasterization module raises errors about the data type:

[Screenshot: data type error raised by diff-gaussian-rasterization]
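
One common workaround (not something this repo prescribes) is to keep the rasterizer call in FP32 while the rest of the network runs under FP16 autocast, along the lines of the sketch below. The wrapper and argument names are illustrative, not pixelSplat's actual API.

```python
# Hedged sketch: run the CUDA rasterizer in FP32 even when the surrounding
# network uses FP16 autocast, since the rasterization kernels expect float32.
import torch

def rasterize_in_fp32(rasterizer, **gaussian_tensors):
    # Upcast any half-precision inputs so the CUDA kernels only see float32.
    fp32_inputs = {
        name: (t.float() if torch.is_tensor(t) and t.dtype == torch.float16 else t)
        for name, t in gaussian_tensors.items()
    }
    # Disable autocast locally so intermediate ops inside the call stay FP32.
    with torch.autocast(device_type="cuda", enabled=False):
        return rasterizer(**fp32_inputs)
```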