Pointcept / PointTransformerV3

[CVPR'24 Oral] Official repository of Point Transformer V3 (PTv3)

How to modify model parameters to reduce excessive GPU memory consumption #20

Closed: pmj110119 closed this issue 3 months ago

pmj110119 commented 3 months ago

Hi, thank you for your outstanding work!

I'm looking to employ PT-V3 as my point-cloud backbone with an output feature channel of 512. However, using the default settings, I encountered a CUDA OOM error during the forward pass.

Comparatively, with my data (batch size = 8, 15,000 points per cloud), SparseUnet requires only 2GB of memory, while PT-V3 demands a substantial 20GB. This disparity is quite significant, and I'm wondering whether there is an error in my configuration.

My code:

import torch
from PointTransformerV3.model import PointTransformerV3

bs = 8
npts = 15000
len_xyz = 3
feat_dims = 3
grid_size = 0.005

# cumulative point counts per batch element
offset = [npts * (i + 1) for i in range(bs)]
offset = torch.tensor(offset).cuda()

pts = torch.rand(bs, npts, len_xyz).cuda()
coord = pts / grid_size          # convert coordinates to grid units
feat = torch.ones_like(pts)

patch_sz = 1024
model = PointTransformerV3(
    in_channels=3,
    stride=(2, 2, 2, 2),
    enc_depths=(2, 2, 2, 2, 2),
    enc_channels=(16, 32, 64, 128, 256),
    enc_num_head=(2, 4, 8, 16, 32),
    enc_patch_size=(patch_sz, patch_sz, patch_sz, patch_sz, patch_sz),
    dec_depths=(2, 2, 2, 2),
    dec_channels=(512, 256, 256, 256),
    dec_num_head=(4, 4, 8, 16),
    dec_patch_size=(patch_sz, patch_sz, patch_sz, patch_sz),
    mlp_ratio=2,
).cuda()

data_dict = dict(
    coord=coord,
    feat=feat,
    offset=offset,
    grid_size=grid_size,
)
# flatten (bs, npts, 3) -> (bs * npts, 3)
data_dict['coord'] = data_dict['coord'].reshape(bs * npts, 3)
data_dict['feat'] = data_dict['feat'].reshape(bs * npts, 3)

pred = model(data_dict)

Gofinge commented 3 months ago

Hi, have you made sure flash_attention is enabled? Without flash attention, large matrix multiplications (1024 x 1024) occur in the attention kernels, which might cause OOM. If flash attention cannot be enabled due to your local environment, you can consider reducing the patch size to the 128 or 256 level.
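
A minimal sketch of both knobs, reusing only constructor arguments that already appear in this thread (enable_flash and the patch-size tuples); it assumes the remaining constructor arguments have usable library defaults, otherwise copy the full argument list from the snippet above:

patch_sz = 256  # reduced from the 1024 used above
model = PointTransformerV3(
    in_channels=3,
    enc_patch_size=(patch_sz,) * 5,   # one entry per encoder stage
    dec_patch_size=(patch_sz,) * 4,   # one entry per decoder stage
    enable_flash=True,                # set False only if flash-attn cannot be installed
).cuda()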

Gofinge commented 3 months ago

Thanks for providing the demo code. I ran it with a few modifications in my local environment, and the forward pass runs on my local machine with a single 4090. So some of the initial model parameters cause the issue; please further check that the model parameters are aligned with our official configs (https://github.com/Pointcept/Pointcept/blob/main/configs/scannet200/semseg-pt-v3m1-0-base.py#L15-L51).

pmj110119 commented 3 months ago

> Thanks for providing the demo code. I ran it with a few modifications in my local environment, and the forward pass runs on my local machine with a single 4090. So some of the initial model parameters cause the issue; please further check that the model parameters are aligned with our official configs (https://github.com/Pointcept/Pointcept/blob/main/configs/scannet200/semseg-pt-v3m1-0-base.py#L15-L51).

Yes, a 4090 (24GB) GPU can manage the forward process, but computing the loss and executing loss.backward() consumes a significant amount of GPU memory, leading to an Out Of Memory (OOM) situation.

My confusion lies in the fact that both models have four stages and an output channel of 512, yet SparseUnet's forward process only occupies 2GB, whereas PT-V3 consumes 10 times that amount of memory. Could you share the GPU memory usage observed in your tests?

pmj110119 commented 3 months ago

> Hi, have you made sure flash_attention is enabled? Without flash attention, large matrix multiplications (1024 x 1024) occur in the attention kernels, which might cause OOM. If flash attention cannot be enabled due to your local environment, you can consider reducing the patch size to the 128 or 256 level.

Yes, I set enable_flash=True.

Gofinge commented 3 months ago

Okay, I noticed that you set the decoder channels to [512, 256, 256, 256] instead of our default setting, which is [384, 192, 96, 48]. So in the last stage of the decoding process there is a huge number of points (120,000), and with an attention mechanism we cannot avoid consuming a large amount of memory when the point feature channels are also forced to be that wide.

Gofinge commented 3 months ago

If you have to keep the output channels at 256, maybe you can try reducing the number of decoder blocks in the last one or two stages to 0 (I don't think a complex decoding process is necessary; image transformers likewise use a lightweight patch embedding to avoid attending directly over pixels).

pmj110119 commented 3 months ago

Thanks for the reply! I set dec_depths=(2, 2, 0, 0) and dec_channels=(512, 16, 16, 16) (the feature channels are quite small), yet the CUDA memory usage still peaks at 18GB.

It seems that, regardless of how compact the network is, a large number of points will inevitably consume substantial amounts of GPU memory?

model = PointTransformerV3(
    in_channels=3,
    stride=(2, 2, 2, 2),
    enc_depths=(2, 2, 2, 2, 2),
    enc_channels=(16, 32, 64, 128, 256),
    enc_num_head=(2, 4, 8, 16, 32),
    enc_patch_size=(patch_sz, patch_sz, patch_sz, patch_sz, patch_sz),
    dec_depths=(2, 2, 0, 0),
    dec_channels=(512, 16, 16, 16),
    dec_num_head=(4, 4, 4, 4),
    dec_patch_size=(patch_sz, patch_sz, patch_sz, patch_sz),
    mlp_ratio=2,
).cuda()

Gofinge commented 3 months ago

Okay, I think I know the reason. Since the simulated input data does not match the distribution of real-world data, the points are almost not downsampled by each pooling stage (which is controlled by the grid size).
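
As a quick sanity check of how much the grid pooling can reduce a cloud, one can count the occupied grid cells at a given grid size. A rough sketch, which simply buckets points by their integer grid coordinates (an assumption about how the pooling keys points):

import torch

def occupied_cells(pts, grid_size):
    # bucket each point into an integer grid cell and count distinct cells
    cells = torch.floor(pts / grid_size).long()
    return torch.unique(cells, dim=0).shape[0]

pts = torch.rand(15000, 3)           # simulated cloud with ~1 m extent
print(occupied_cells(pts, 0.005))    # ~15,000 cells: almost no reduction
print(occupied_cells(pts, 0.05))     # far fewer cells: real downsampling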

pmj110119 commented 3 months ago

Unfortunately, I load real-world data (8 x 15000 x 3 point clouds; the xyz range is about 1 m), but it still consumes about 18GB 😭 pts.zip (unzip to get pts.npy, an ndarray)

Gofinge commented 3 months ago

> Unfortunately, I load real-world data (8 x 15000 x 3 point clouds; the xyz range is about 1 m), but it still consumes about 18GB 😭 pts.zip (unzip to get pts.npy, an ndarray)

grid_size = 0.005 -> 0.05 (we use 2 cm for indoor scenes and 5 cm for outdoor scenes)

pmj110119 commented 3 months ago

Thanks, the GPU memory consumption has been reduced to 13GB. But grid_size = 0.05 corresponds to 5 cm (my data is a table-top scene). I suspect this might be too coarse, potentially making it insufficient for effective point segmentation?

Gofinge commented 3 months ago

To be honest, the memory consumption is strange. In my local environment, I set the batch size to 2 and tested with the ScanNet dataset; the number of points fed to the model is about 204,800, and the memory cost is as follows (the base consumption of my PC is about 2.2GB):

[screenshot of GPU memory usage]

And here is batch size 4 (409,600 points) on a single GPU: [screenshot of GPU memory usage]
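
If it helps to compare numbers rather than screenshots, PyTorch's allocator counters report the peak usage directly. A minimal sketch around the forward pass, with model and data_dict as defined earlier in this thread:

torch.cuda.reset_peak_memory_stats()
pred = model(data_dict)        # forward only
torch.cuda.synchronize()
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")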

pmj110119 commented 3 months ago

Could you share what your dec_channels setting is?

Gofinge commented 3 months ago

Same as our default configs for all settings (https://github.com/Pointcept/Pointcept/blob/main/configs/scannet200/semseg-pt-v3m1-0-base.py#L15-L51)

pmj110119 commented 3 months ago

Hi, it turns out I was handling the point-cloud coordinates incorrectly with the following line of code (carried over from my SparseUnet pipeline):

coord = pts / grid_size

After deleting this line of code, the memory cost is significantly reduced. Thank you very much for your patient help!
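
For reference, a minimal sketch of the corrected input preparation implied by the fix above, assuming the model expects raw metric coordinates and performs the grid pooling itself via grid_size:

# feed metric xyz directly; grid_size alone controls how coarsely the
# model pools, so the earlier pre-division by grid_size is dropped
data_dict = dict(
    coord=pts.reshape(bs * npts, 3),   # was: (pts / grid_size).reshape(...)
    feat=feat.reshape(bs * npts, 3),
    offset=offset,
    grid_size=grid_size,
)
pred = model(data_dict)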