cvg / depthsplat

DepthSplat: Connecting Gaussian Splatting and Depth
https://haofeixu.github.io/depthsplat/
MIT License

OOM when I train with 12 image inputs (custom dataset, resolution 256x256) #23

Closed: thucz closed this issue 2 days ago

thucz commented 1 week ago

Do you have any idea how to solve this problem? Are there any parameters I could change to reduce GPU memory?

haofeixu commented 1 week ago

Hi, could you share your training config/scripts and your GPU memory size so that we can look into this issue? Thanks.

thucz commented 1 week ago

The config is similar to re10k, but with 12 context views and 6 target views. DepthSplat can only be trained when I freeze the depth predictor.
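
The freezing workaround mentioned above can be sketched in PyTorch. Note that `depth_predictor` and `TinyModel` below are hypothetical placeholder names for illustration, not the actual module names used in DepthSplat:

```python
import torch.nn as nn

# Hypothetical stand-in for a model with a depth branch; the real
# module names inside DepthSplat may differ.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.depth_predictor = nn.Linear(8, 8)  # pretend monodepth branch
        self.gaussian_head = nn.Linear(8, 8)    # pretend Gaussian regressor

model = TinyModel()

# Freeze the depth branch: its weights get no gradients and no
# optimizer state, which reduces training memory.
for p in model.depth_predictor.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Freezing saves memory because gradients and optimizer states (e.g. Adam's two moment buffers) are no longer allocated for the frozen parameters, though activations for the forward pass are still needed.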

```shell
data_loader.train.batch_size=1 \
dataset.test_chunk_interval=10 \
trainer.val_check_interval=0.5 \
trainer.max_steps=100000 \
model.encoder.upsample_factor=8 \
model.encoder.lowest_feature_resolution=8 \
model.encoder.gaussian_regressor_channels=16 \
model.encoder.feature_upsampler_channels=64 \
model.encoder.return_depth=true \
checkpointing.pretrained_monodepth=pretrained/depth_anything_v2_vits.pth \
checkpointing.pretrained_mvdepth=pretrained/gmflow-scale1-things-e9887eda.pth \
wandb.project=depthsplat \
output_dir=checkpoints/re10k-depthsplat-small \
model.encoder.grid_sample_disable_cudnn=true
```

haofeixu commented 1 week ago

Hi, for training with more than 2 input views, we suggest following the DL3DV training script (with +experiment=dl3dv): https://github.com/cvg/depthsplat/blob/main/scripts/dl3dv_256x448_depthsplat_base.sh#L27-L64. The view sampler for more than 2 views (fps) is different from the one for re10k (bounded).
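
As a rough illustration of why an fps-style sampler helps with many views, here is a minimal farthest-point sampling sketch over camera positions. This is a generic sketch of the idea, not the repo's actual sampler code:

```python
import numpy as np

def farthest_point_sample(positions, n_views, start=0):
    """Greedy farthest-point sampling over camera positions.

    Start from one camera and repeatedly pick the camera farthest from
    all already-selected ones, so the chosen context views cover the
    capture trajectory instead of clustering together.
    """
    selected = [start]
    # distance from every camera to its nearest selected camera
    dist = np.linalg.norm(positions - positions[start], axis=1)
    while len(selected) < n_views:
        nxt = int(dist.argmax())          # farthest from current selection
        selected.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(positions - positions[nxt], axis=1))
    return selected

# Example: 8 cameras spaced along a line; fps picks spread-out indices.
cams = np.array([[float(i), 0.0, 0.0] for i in range(8)])
picked = farthest_point_sample(cams, 3)
```

A bounded sampler, by contrast, picks target views between two nearby context views, which only makes sense for the 2-view setting.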

In addition, you need to enable local cross-view attention and a local cost volume with

```shell
model.encoder.multiview_trans_nearest_n_views=3 \
model.encoder.costvolume_nearest_n_views=3 \
```

Otherwise the model by default uses global attention, which would be expensive for many input views.
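
A back-of-the-envelope count shows the saving from restricting each view to its nearest neighbors. `attention_pairs` below is an illustrative helper, not part of the codebase, and it only counts view pairs; the per-pair cost additionally scales with the number of feature tokens per view:

```python
def attention_pairs(n_views, nearest_n=None):
    """Count (query view, key view) pairs in cross-view attention.

    Global attention: every view attends to every view, O(N^2) pairs.
    Local attention: each view attends only to its nearest_n views
    (capped at N), O(N * nearest_n) pairs.
    """
    if nearest_n is None:
        return n_views * n_views
    return n_views * min(nearest_n, n_views)

# With 12 context views, nearest_n_views=3 cuts the pair count 4x:
global_cost = attention_pairs(12)               # 12 * 12 = 144 pairs
local_cost = attention_pairs(12, nearest_n=3)   # 12 * 3  = 36 pairs
```

The same reasoning applies to the cost volume: building it only against the 3 nearest views keeps its size linear in the number of input views rather than quadratic.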

thucz commented 1 day ago

Thanks!