dcharatan / pixelsplat

[CVPR 2024 Oral, Best Paper Runner-Up] Code for "pixelSplat: 3D Gaussian Splats from Image Pairs for Scalable Generalizable 3D Reconstruction" by David Charatan, Sizhe Lester Li, Andrea Tagliasacchi, and Vincent Sitzmann
http://davidcharatan.com/pixelsplat/
MIT License
899 stars, 62 forks

Weird Error: Out of memory error when running on V100 GPUs with the smaller batch #8

Closed: thucz closed this issue 8 months ago

thucz commented 10 months ago

Hi! I ran into an error that I cannot understand. I can run the code on an A10 (22 GB) and an A100 (40 GB) with a smaller batch size, but I cannot run it on a V100 (32 GB).

The error is weird:

Error executing job with overrides: ['+experiment=re10k']                                                            
Traceback (most recent call last):
  File "/group/30042/ozhengchen/scene_gen/pixelsplat/src/main.py", line 123, in train                                                                                                                                                     
    trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit                                                                                   
    call._call_and_handle_interrupt(    
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt                                                                
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 102, in launch                                                         
    return function(*args, **kwargs) 
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl                                                                             
    self._run(model, ckpt_path=ckpt_path)                      
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run                                                                                  
    results = self._run_stage()                                                                     
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()                                                         
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)           
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 134, in run                                                                             
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 391, in _evaluation_step                                                                
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook                                                                      
    output = fn(*args, **kwargs)                     
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 402, in validation_step                                                                   
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in __call__                                                                          
    wrapper_output = wrapper_module(*args, **kwargs)                                                
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                             
    return self._call_impl(*args, **kwargs)            
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                                     
    return forward_call(*args, **kwargs)               
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward                                                                                  
    else self._run_ddp_forward(*inputs, **kwargs)      
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward                                                                         
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                             
    return self._call_impl(*args, **kwargs)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                                     
    return forward_call(*args, **kwargs)         
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 626, in wrapped_forward                                                                   
    out = method(*_args, **_kwargs)   
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 43, in wrapped_fn                                                                            
    return fn(*args, **kwargs)                   
  File "/group/30042/ozhengchen/scene_gen/pixelsplat/./src/model/model_wrapper.py", line 212, in validation_step
    output_probabilistic = self.decoder.forward(
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 409, in wrapped_fn                                                                                         
    out = fn(*args, **kwargs)                   
  File "/group/30042/ozhengchen/scene_gen/pixelsplat/./src/model/decoder/decoder_splatting_cuda.py", line 46, in forward                                                                                                                   
    color = render_cuda(                                                                                                                                          
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 409, in wrapped_fn                                                                                         
    out = fn(*args, **kwargs)                                        
  File "/group/30042/ozhengchen/scene_gen/pixelsplat/./src/model/decoder/cuda_splatting.py", line 117, in render_cuda
    image, radii = rasterizer(    
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                             
    return self._call_impl(*args, **kwargs)                                      
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                                     
    return forward_call(*args, **kwargs)
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/diff_gaussian_rasterization/__init__.py", line 210, in forward                                                                            
    return rasterize_gaussians(                                                          
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/diff_gaussian_rasterization/__init__.py", line 32, in rasterize_gaussians                                                                 
    return _RasterizeGaussians.apply(
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply                                                                                           
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/group/30042/ozhengchen/ft_local/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/diff_gaussian_rasterization/__init__.py", line 92, in forward                                                                             
    num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 131071.75 GiB. GPU 0 has a total capacty of 31.75 GiB of which 27.95 GiB is free. Process 88117 has 3.79 GiB memory in use. Of the allocated memory 1.66 GiB is allocated by PyTorch, and 340.91 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

Amazingly, it reports that 131071.75 GiB is needed. I don't know what bug causes this when running the code on a V100 GPU.
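To put that number in perspective, the sizes in the message can be converted to bytes. The short Python sketch below uses only the figures printed in the error above (PyTorch reports them in GiB, where 1 GiB = 2**30 bytes). The request works out to just under 2**47 bytes (128 TiB), about 4,128 times the card's total capacity, which looks more like a bogus size computation than a genuine shortage of memory; the arithmetic alone does not say where that size comes from.

```python
GIB = 2**30  # PyTorch's memory messages use GiB: 1 GiB = 2**30 bytes

requested_gib = 131071.75  # "Tried to allocate 131071.75 GiB"
total_gib = 31.75          # "total capacty of 31.75 GiB" (the V100)

requested_bytes = int(requested_gib * GIB)
print(requested_bytes)                   # 140737219919872
print(requested_bytes == 2**47 - 2**28)  # True, i.e. just under 128 TiB
print(round(requested_gib / total_gib))  # 4128
```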

dcharatan commented 10 months ago

What's the exact configuration (command line overrides and modifications to the config files) you're using? Does any other configuration work on the V100?

thucz commented 10 months ago


I only changed the batch size to 2 (or 1) and the dataset path to my own path. The command is unchanged: python src/main.py +experiment=re10k. So far I have not found a configuration that works on the V100. Here is my experiment config:

# @package _global_

defaults:
  - override /dataset: re10k
  - override /model/encoder: epipolar
  - override /model/encoder/backbone: dino
  - override /loss: [mse, lpips]

wandb:
  name: re10k
  tags: [re10k, 256x256]

dataset:
  image_shape: [256, 256]
  roots: [/group/30042/ozhengchen/scene_gen/pixelSplat_data/re10k_subset]
# datasets/re10k

data_loader:
  train:
    batch_size: 2 #7

trainer:
  max_steps: 300_001

dcharatan commented 10 months ago

In case it's an issue with the environment, here's the exact environment (output of pip freeze) I used:

aiohttp==3.8.5
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
beartype==0.15.0
beautifulsoup4==4.12.2
black==23.7.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.7
colorama==0.4.6
colorspacious==1.1.2
contourpy==1.1.0
cycler==0.11.0
dacite==1.8.1
decorator==4.4.2
diff-gaussian-rasterization @ git+https://github.com/dcharatan/diff-gaussian-rasterization-modified.git@ec3c8ee5a50296550b38db48eeb8fcdf1c540900
docker-pycreds==0.4.0
e3nn==0.5.1
einops==0.6.1
filelock==3.9.0
fonttools==4.42.1
frozenlist==1.4.0
fsspec==2023.4.0
gdown==4.7.1
gitdb==4.0.10
GitPython==3.1.32
huggingface-hub==0.17.2
hydra-core==1.3.2
idna==3.4
imageio==2.31.1
imageio-ffmpeg==0.4.8
jaxtyping==0.2.21
Jinja2==3.1.2
kiwisolver==1.4.4
lazy_loader==0.3
lightning-utilities==0.9.0
lines==0.0.0
lpips==0.1.4
lxml==4.9.3
MarkupSafe==2.1.2
matplotlib==3.9.0.dev0
moviepy==1.0.3
mpmath==1.2.1
multidict==6.0.4
mypy-extensions==1.0.0
networkx==3.0rc1
numpy==1.24.1
omegaconf==2.3.0
opt-einsum==3.3.0
opt-einsum-fx==0.1.4
packaging==23.1
pathspec==0.11.2
pathtools==0.1.2
Pillow==10.1.0
platformdirs==3.10.0
plyfile==1.0.1
proglog==0.1.10
protobuf==4.24.1
psutil==5.9.5
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
pytorch-lightning==2.0.7
pytorch-triton==2.1.0+e6216047b8
PyWavelets==1.4.1
PyYAML==6.0.1
requests==2.28.1
ruff==0.0.285
safetensors==0.3.3
scikit-image==0.21.0
scipy==1.11.2
sentry-sdk==1.29.2
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
soupsieve==2.5
svg.py==1.4.2
svgutils==0.3.4
sympy==1.11.1
tabulate==0.9.0
tifffile==2023.8.12
timm==0.9.7
tomli==2.0.1
torch==2.1.0.dev20230820+cu121
torchaudio==2.1.0.dev20230821+cu121
torchmetrics==1.0.3
torchvision==0.16.0.dev20230821+cu121
tqdm==4.66.1
trimesh==3.23.5
typeguard==4.1.2
typing_extensions==4.7.1
urllib3==1.26.13
wandb==0.15.8
yarl==1.9.2
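Since the environment was the first suspect, a small helper can compare two `pip freeze` listings like the one above and report only the packages that differ. This is a hypothetical, standard-library-only sketch (the function names are made up for illustration); it understands `name==version` and `name @ url` lines and normalizes package names with the PEP 503 rule.

```python
import re


def parse_freeze(text: str) -> dict[str, str]:
    """Map normalized package names to version (or URL) strings from `pip freeze` output."""
    packages = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if " @ " in line:  # direct reference, e.g. a git URL with a commit hash
            name, spec = line.split(" @ ", 1)
        elif "==" in line:
            name, spec = line.split("==", 1)
        else:
            continue  # skip lines this sketch does not understand
        # PEP 503 normalization: runs of "-", "_" and "." become "-", then lowercase
        packages[re.sub(r"[-_.]+", "-", name).lower()] = spec.strip()
    return packages


def diff_envs(a: str, b: str) -> dict[str, tuple]:
    """Return {name: (spec_in_a, spec_in_b)} for every package that differs; None = missing."""
    pa, pb = parse_freeze(a), parse_freeze(b)
    return {
        name: (pa.get(name), pb.get(name))
        for name in sorted(pa.keys() | pb.keys())
        if pa.get(name) != pb.get(name)
    }


# Example with two torch builds mentioned in this thread
mine = "torch==2.1.0.dev20230820+cu121\nnumpy==1.24.1\n"
yours = "torch==2.1.2+cu118\nnumpy==1.24.1\n"
print(diff_envs(mine, yours))  # {'torch': ('2.1.0.dev20230820+cu121', '2.1.2+cu118')}
```

A columnar listing such as the output of `pip list` can be converted to the same format with `pip list --format=freeze` before comparing.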

thucz commented 9 months ago

I think the problem is in the compiled https://github.com/dcharatan/diff-gaussian-rasterization-modified, which should also be compiled on the V100 GPU. Have you ever tried to run your code on a V100 card?

When I use the original diff-gaussian-rasterization, the error disappears. I wonder what differences there are between your modified diff-gaussian-rasterization and the original one.

dcharatan commented 9 months ago

You can see the differences here. As far as I can tell, there are no modifications that would cause the memory usage to explode. Are you compiling the original and the modified version the same way?

kevinYitshak commented 9 months ago

Hi, I am also facing the same issue, even on an A100 with 80 GB of memory.

After switching to the original diff-gaussian-rasterization, the error disappears. Is this the right fix? I compiled both rasterizers the same way.

thucz commented 9 months ago

I compiled both rasterizers in the same way. The OOM error only appears in the modified one.

Pixie8888 commented 9 months ago

Hi, I am also facing the same OOM problem. I use PyTorch 1.8. Do you know how to solve it? @thucz @kevinYitshak

dcharatan commented 9 months ago

@Pixie8888 @thucz @kevinYitshak Could you please post the exact environment (pip/conda environment, Python version, OS, etc.) that you're using? I can try to look into the issue some more.

Pixie8888 commented 9 months ago

Hi @dcharatan , below is the environment I use.

Package                     Version         Editable project location
--------------------------- --------------- -------------------------
absl-py                     2.1.0
addict                      2.4.0
antlr4-python3-runtime      4.9.3
anyio                       4.2.0
appdirs                     1.4.4
argon2-cffi                 23.1.0
argon2-cffi-bindings        21.2.0
arrow                       1.3.0
asttokens                   2.4.1
async-lru                   2.0.4
attrs                       23.2.0
Babel                       2.14.0
backcall                    0.2.0
beautifulsoup4              4.12.3
black                       24.1.1
bleach                      6.1.0
cachetools                  4.2.4
certifi                     2024.2.2
cffi                        1.16.0
charset-normalizer          3.3.2
click                       8.1.7
colorama                    0.4.6
colorspacious               1.1.2
comm                        0.2.1
contourpy                   1.1.1
cycler                      0.12.1
Cython                      0.29.33
dacite                      1.8.1
debugpy                     1.8.1
decorator                   5.1.1
defusedxml                  0.7.1
descartes                   1.1.0
diff-gaussian-rasterization 0.0.0
docker-pycreds              0.4.0
e3nn                        0.5.1
einops                      0.7.0
exceptiongroup              1.2.0
executing                   2.0.1
fastjsonschema              2.19.1
fire                        0.5.0
flake8                      7.0.0
fonttools                   4.48.1
fqdn                        1.5.1
future                      0.18.3
gitdb                       4.0.11
GitPython                   3.1.41
google-auth                 1.35.0
google-auth-oauthlib        0.4.6
grpcio                      1.48.2
h11                         0.14.0
httpcore                    1.0.2
httpx                       0.26.0
idna                        3.6
imageio                     2.33.1
importlib-metadata          7.0.1
importlib-resources         6.1.1
iniconfig                   2.0.0
ipykernel                   6.29.2
ipython                     8.12.3
ipywidgets                  8.1.2
isoduration                 20.11.0
jaxtyping                   0.2.19
jedi                        0.19.1
Jinja2                      3.1.3
joblib                      1.3.2
json5                       0.9.14
jsonpointer                 2.4
jsonschema                  4.21.1
jsonschema-specifications   2023.12.1
jupyter                     1.0.0
jupyter_client              8.6.0
jupyter-console             6.6.3
jupyter_core                5.7.1
jupyter-events              0.9.0
jupyter-lsp                 2.2.2
jupyter_server              2.12.5
jupyter_server_terminals    0.5.2
jupyterlab                  4.1.0
jupyterlab_pygments         0.3.0
jupyterlab_server           2.25.2
jupyterlab_widgets          3.0.10
kiwisolver                  1.4.5
llvmlite                    0.31.0
lyft-dataset-sdk            0.0.8
Markdown                    3.5.2
MarkupSafe                  2.1.5
matplotlib                  3.7.4
matplotlib-inline           0.1.6
mccabe                      0.7.0
mistune                     3.0.2
mkl-fft                     1.3.1
mkl-random                  1.2.2
mkl-service                 2.4.0
mmcv-full                   1.2.7
mmdet                       2.11.0
mmdet3d                     0.8.0           /data/ytxu/nerf/my_exp
mmpycocotools               12.0.3
mpmath                      1.3.0
mypy-extensions             1.0.0
nbclient                    0.9.0
nbconvert                   7.16.0
nbformat                    5.9.2
nest-asyncio                1.6.0
networkx                    2.2
notebook                    7.0.7
notebook_shim               0.2.3
numba                       0.48.0
numpy                       1.23.1
nuscenes-devkit             1.1.2
oauthlib                    3.2.2
olefile                     0.47
omegaconf                   2.3.0
opencv-python               4.9.0.80
opt-einsum                  3.3.0
opt-einsum-fx               0.1.4
overrides                   7.7.0
packaging                   23.2
pandas                      2.0.3
pandocfilters               1.5.1
parso                       0.8.3
pathspec                    0.12.1
pexpect                     4.9.0
pickleshare                 0.7.5
pillow                      10.2.0
pip                         23.3.1
pkgutil_resolve_name        1.3.10
platformdirs                4.2.0
plotly                      5.18.0
pluggy                      1.4.0
plyfile                     0.7.3
prometheus-client           0.19.0
prompt-toolkit              3.0.43
protobuf                    3.20.3
psutil                      5.9.8
ptyprocess                  0.7.0
pure-eval                   0.2.2
pyasn1                      0.5.1
pyasn1-modules              0.3.0
pycocotools                 2.0.7
pycodestyle                 2.11.1
pycparser                   2.21
pyflakes                    3.2.0
Pygments                    2.17.2
pyparsing                   3.1.1
pyquaternion                0.9.9
pytest                      8.0.0
python-dateutil             2.8.2
python-json-logger          2.0.7
pytorch-lightning           0.8.5
pytz                        2024.1
PyWavelets                  1.4.1
PyYAML                      6.0.1
pyzmq                       25.1.2
qtconsole                   5.5.1
QtPy                        2.4.1
referencing                 0.33.0
requests                    2.31.0
requests-oauthlib           1.3.1
rfc3339-validator           0.1.4
rfc3986-validator           0.1.1
rpds-py                     0.17.1
rsa                         4.9
scikit-image                0.18.1
scikit-learn                1.3.2
scipy                       1.10.1
Send2Trash                  1.8.2
sentry-sdk                  1.40.2
setproctitle                1.3.3
setuptools                  68.2.2
shapely                     2.0.2
six                         1.16.0
smmap                       5.0.1
sniffio                     1.3.0
sort-vertices               0.0.0
soupsieve                   2.5
stack-data                  0.6.3
sympy                       1.12
tenacity                    8.2.3
tensorboard                 2.1.1
termcolor                   2.4.0
terminado                   0.18.0
terminaltables              3.1.10
threadpoolctl               3.2.0
tifffile                    2023.7.10
tinycss2                    1.2.1
tomli                       2.0.1
torch                       1.8.0
torchaudio                  0.8.0a0+a751e1d
torchvision                 0.9.0
tornado                     6.4
tqdm                        4.66.1
traitlets                   5.14.1
trimesh                     2.35.39
typeguard                   4.1.5
types-python-dateutil       2.8.19.20240106
typing_extensions           4.9.0
tzdata                      2023.4
uri-template                1.3.0
urllib3                     2.2.0
wandb                       0.16.3
wcwidth                     0.2.13
webcolors                   1.13
webencodings                0.5.1
websocket-client            1.7.0
Werkzeug                    3.0.1
wheel                       0.41.2
widgetsnbextension          4.0.10
yapf                        0.40.1
zipp                        3.17.0

When it reaches the decoder, it fails with an OOM error:

  File "/home/yating/anaconda3/envs/nerf_8/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 210, in forward
    return rasterize_gaussians(
  File "/home/yating/anaconda3/envs/nerf_8/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 32, in rasterize_gaussians
    return _RasterizeGaussians.apply(
  File "/home/yating/anaconda3/envs/nerf_8/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 92, in forward
    num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
RuntimeError: CUDA out of memory. Tried to allocate 130663.06 GiB (GPU 0; 23.69 GiB total capacity; 6.31 GiB already allocated; 14.75 GiB free; 7.14 GiB reserved in total by PyTorch)

kevinYitshak commented 9 months ago

Hi, for me this fix solved it: https://github.com/graphdeco-inria/gaussian-splatting/issues/99#issuecomment-1771659438

Pixie8888 commented 9 months ago

Hi @kevinYitshak, I still get the same error after following the link. My environment is PyTorch 1.8.0 + CUDA 11.1 on an RTX 3090 Ti.

thucz commented 9 months ago

Hi @kevinYitshak! I also still get the same error after following the link. My environment is PyTorch 2.1.0 + CUDA 12.1 on a V100.

kevinYitshak commented 8 months ago

Hi! Sorry for the late reply. I am using torch==2.1.2+cu118. I tested on both V100 and A40S GPUs, and it works without any issues.

These are the packages in my environment (I'm not sure this helps):

name: pixelsplat
channels:

thucz commented 8 months ago

@kevinYitshak Thanks for your advice! torch==2.1.2+cu118 works in my V100 environment.

dcharatan commented 6 months ago

I've observed this happening if I compile diff-gaussian-rasterization (the original) on a 4090 and then try to use the binaries on an A100. On the other hand, compiling on an A100 and then using the binaries on a 4090 works.
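That asymmetry is consistent with CUDA's binary-compatibility rule: machine code (cubin) compiled for compute capability X.y runs only on devices of capability X.z with z ≥ y, while PTX embedded in the binary can be JIT-compiled by the driver for the same or a newer capability. The sketch below encodes that rule in plain Python. The capabilities used (RTX 4090 = 8.9, A100 = 8.0, V100 = 7.0) are NVIDIA's published values; the helper `can_run` is hypothetical, and connecting this rule to the 131071.75 GiB allocation is a hypothesis, not something confirmed in this thread.

```python
from collections.abc import Iterable

Capability = tuple[int, int]  # (major, minor) compute capability


def can_run(targets: Iterable[tuple[Capability, bool]], device: Capability) -> bool:
    """Return True if any compiled target is usable on `device`.

    Each target is (capability it was compiled for, whether PTX was embedded).
    """
    for built_for, has_ptx in targets:
        # cubin (SASS) for X.y runs only on X.z with z >= y (same major version)
        if built_for[0] == device[0] and device[1] >= built_for[1]:
            return True
        # PTX for X.y can be JIT-compiled by the driver on any device >= X.y
        if has_ptx and device >= built_for:
            return True
    return False


RTX_4090, A100, V100 = (8, 9), (8, 0), (7, 0)

print(can_run([(RTX_4090, False)], A100))  # False: built on a 4090, run on an A100
print(can_run([(A100, False)], RTX_4090))  # True: built on an A100, run on a 4090
print(can_run([(A100, True)], V100))       # False: code built for sm_80 cannot run on sm_70
```

If this is the mechanism, building the rasterizer on the GPU that will run it, or setting `TORCH_CUDA_ARCH_LIST` (for example `"7.0;8.0"`) before compiling so that extensions built with `torch.utils.cpp_extension` target every GPU you use, would be the natural thing to try; nobody in this thread has verified that.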