drprojects / superpoint_transformer

Official PyTorch implementation of Superpoint Transformer introduced in [ICCV'23] "Efficient 3D Semantic Segmentation with Superpoint Transformer" and SuperCluster introduced in [3DV'24 Oral] "Scalable 3D Panoptic Segmentation As Superpoint Graph Clustering"
MIT License

Using tiling and voxelization to address CUDA out of Memory errors on custom dataset #144

Closed: ImaneTopo closed this issue 1 month ago

ImaneTopo commented 1 month ago

I want to train the model to detect small instances. My scenes are relatively small, so I set xy_tiling to None in configs/datamodule/semantic/dales.yaml and also in src/dataset/base.py. I then made some modifications in dales.yaml, specifically voxel = 0.001 and knn_r = 0.1. However, I did not get good instance_iou_object results, neither in train nor in val. Moreover, when I run the model it raises a CUDA out of memory error. I don't know how to fix this problem, and at the same time I don't want to subsample my data.
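
A rough count shows why a 1 mm voxel can exhaust GPU memory even on a small scene. The sketch below assumes a thin, densely sampled scanned surface in which every voxel the surface crosses is occupied, so the voxel count scales as area / voxel². The helper name approx_surface_voxels and the sphere approximation of a 4 cm object are illustrative assumptions, not part of the repository.

```python
import math

def approx_surface_voxels(area_m2: float, voxel_m: float) -> int:
    """Rough number of occupied voxels for a scanned surface.

    Assumption: a thin, densely sampled surface occupies about one voxel
    per voxel_m**2 of area, so the count grows as 1 / voxel_m**2.
    """
    return round(area_m2 / voxel_m ** 2)

# Hypothetical inputs: 1 m^2 of scanned surface, and one 4 cm object
# approximated as a sphere (surface area = pi * d**2)
object_area = math.pi * 0.04 ** 2

for voxel in (0.001, 0.005, 0.01):
    per_m2 = approx_surface_voxels(1.0, voxel)
    per_object = approx_surface_voxels(object_area, voxel)
    print(f"voxel={voxel} m: ~{per_m2:,} voxels per m^2, ~{per_object:,} per object")
```

Under these assumptions, going from 1 cm to 1 mm voxels multiplies the voxel count by about 100, while a 4 cm object is already covered by roughly 50 voxels at 1 cm and roughly 200 at 5 mm.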

drprojects commented 1 month ago

Hi @ImaneTopo, how many points do you have in your cloud?

The rationale for choosing your xy_tiling, pc_tiling and voxel is not so much about absolute coordinate values (i.e. "my tiles are only 20 m wide") but about the actual number of points (i.e. "my tiles contain 1M points"). Judging from the very small voxel value you are using, you may have an extremely high-resolution point cloud. There are 2 ways to combat the OOM errors stemming from manipulating too many points:

ImaneTopo commented 1 month ago

Thank you for your response. Since my scenes are very small, I removed tiling, because it prevents me from processing the data properly at this scale. As I mentioned, my instances are small and I need centimeter-level precision: they have a diameter of 4 cm. When I increase the voxel size, the training precision drops to almost zero, whereas decreasing it improves precision, reaching up to 29% in validation. However, training was cut off halfway through with a CUDA out of memory error, even though I am working with a 40 GB GPU. Here is the YAML file I am using:

# @package datamodule
defaults:
  - /datamodule/semantic/default.yaml

target: src.datamodules.dales.DALESDataModule

# These parameters are not actually used by the DataModule, but are used
# here to facilitate model parameterization with config interpolation
num_classes: 2
stuff_classes: [0]
trainval: True
val_on_test: True
xy_tiling: null  # YAML null; a bare None is parsed as the string 'None'

# Features that will be computed, saved, loaded for points and segments

# point features used for the partition
partition_hf:
  - 'linearity'
  - 'planarity'
  - 'scattering'
  - 'elevation'

# point features used for training
point_hf:
  - 'intensity'
  - 'linearity'
  - 'planarity'
  - 'scattering'
  - 'verticality'
  - 'elevation'

# segment-wise features computed at preprocessing
segment_base_hf: []

# segment features computed as the mean of point feature in each
# segment, saved with "mean_" prefix
segment_mean_hf: []

# segment features computed as the std of point feature in each segment,
# saved with "std_" prefix
segment_std_hf: []

# horizontal edge features used for training
edge_hf:
  - 'mean_off'
  - 'std_off'
  - 'mean_dist'
  - 'angle_source'
  - 'angle_target'
  - 'centroid_dir'
  - 'centroid_dist'
  - 'normal_angle'
  - 'log_length'
  - 'log_surface'
  - 'log_volume'
  - 'log_size'

v_edge_hf: []  # vertical edge features used for training

# Parameters declared here to facilitate tuning configs without copying
# all the pre_transforms
voxel: 0.001
knn: 25
knn_r: 0.1
knn_step: -1
knn_min_search: 10
ground_threshold: 5
ground_scale: 20
pcp_regularization: [0.1, 0.2, 0.3]
pcp_spatial_weight: [1e-1, 1e-2, 1e-3]
pcp_cutoff: [10, 30, 100]
pcp_k_adjacency: 10
pcp_w_adjacency: 1
pcp_iterations: 15
graph_k_min: 1
graph_k_max: 30
graph_gap: [5, 30, 30]
graph_se_ratio: 0.3
graph_se_min: 20
graph_cycles: 3
graph_margin: 0.5
graph_chunk: [1e6, 1e5, 1e5]  # reduce if CUDA memory errors

# Batch construction parameterization
sample_segment_ratio: 0.3
sample_segment_by_size: True
sample_segment_by_class: True
sample_point_min: 32
sample_point_max: 128
sample_graph_r: 50  # set to r<=0 to skip SampleRadiusSubgraphs
sample_graph_k: 4
sample_graph_disjoint: True
sample_edge_n_min: -1  # [5, 5, 15]
sample_edge_n_max: -1  # [10, 15, 25]

# Augmentations parameterization
pos_jitter: 0.05
tilt_n_rotate_phi: 0.1
tilt_n_rotate_theta: 180
anisotropic_scaling: 0.2
node_feat_jitter: 0
h_edge_feat_jitter: 0
v_edge_feat_jitter: 0
node_feat_drop: 0
h_edge_feat_drop: 0.3
v_edge_feat_drop: 0
node_row_drop: 0
h_edge_row_drop: 0
v_edge_row_drop: 0
drop_to_mean: False

# Preprocessing
pre_transform:
    - transform: SaveNodeIndex
      params:
        key: 'sub'
    - transform: DataTo
      params:
        device: 'cuda'
    - transform: GridSampling3D
      params:
        size: ${datamodule.voxel}
        hist_key: 'y'
        hist_size: ${eval:'${datamodule.num_classes} + 1'}
    - transform: KNN
      params:
        k: ${datamodule.knn}
        r_max: ${datamodule.knn_r}
        verbose: False
    - transform: DataTo
      params:
        device: 'cpu'
    - transform: GroundElevation
      params:
        threshold: ${datamodule.ground_threshold}
        scale: ${datamodule.ground_scale}
    - transform: PointFeatures
      params:
        keys: ${datamodule.point_hf_preprocess}
        k_min: 1
        k_step: ${datamodule.knn_step}
        k_min_search: ${datamodule.knn_min_search}
    - transform: DataTo
      params:
        device: 'cuda'
    - transform: AdjacencyGraph
      params:
        k: ${datamodule.pcp_k_adjacency}
        w: ${datamodule.pcp_w_adjacency}
    - transform: ConnectIsolated
      params:
        k: 1
    - transform: DataTo
      params:
        device: 'cpu'
    - transform: AddKeysTo  # move some features to 'x' to be used for partition
      params:
        keys: ${datamodule.partition_hf}
        to: 'x'
        delete_after: False
    - transform: CutPursuitPartition
      params:
        regularization: ${datamodule.pcp_regularization}
        spatial_weight: ${datamodule.pcp_spatial_weight}
        k_adjacency: ${datamodule.pcp_k_adjacency}
        cutoff: ${datamodule.pcp_cutoff}
        iterations: ${datamodule.pcp_iterations}
        parallel: True
        verbose: False
    - transform: NAGRemoveKeys  # remove 'x' used for partition (features are still preserved under their respective Data attributes)
      params:
        level: 'all'
        keys: 'x'
    - transform: NAGTo
      params:
        device: 'cuda'
    - transform: SegmentFeatures
      params:
        n_min: 10
        n_max: 60
        keys: ${datamodule.segment_base_hf_preprocess}
        mean_keys: ${datamodule.segment_mean_hf_preprocess}
        std_keys: ${datamodule.segment_std_hf_preprocess}
        strict: False  # will not raise error if a mean or std key is missing
    - transform: RadiusHorizontalGraph
      params:
        k_min: ${datamodule.graph_k_min}
        k_max: ${datamodule.graph_k_max}
        gap: ${datamodule.graph_gap}
        se_ratio: ${datamodule.graph_se_ratio}
        se_min: ${datamodule.graph_se_min}
        cycles: ${datamodule.graph_cycles}
        margin: ${datamodule.graph_margin}
        chunk_size: ${datamodule.graph_chunk}
        halfspace_filter: True
        bbox_filter: True
        target_pc_flip: True
        source_pc_sort: False
        keys: ['mean_off', 'std_off', 'mean_dist' ]
    - transform: NAGTo
      params:
        device: 'cpu'

# CPU-based train transforms
train_transform: null

# CPU-based val transforms
val_transform: ${datamodule.train_transform}

# CPU-based test transforms
test_transform: ${datamodule.val_transform}

# GPU-based train transforms
on_device_train_transform:

    # Add a node_size attribute to all segments, this is needed for
    # segment-wise position normalization with UnitSphereNorm
    - transform: NodeSize

    # Apply sampling transforms first to reduce the number of nodes and
    # edges. These operations are compute-intensive and are the reason
    # why these transforms are not performed on CPU

    - transform: NAGRestrictSize
      params:
        level: '1+'
        num_nodes: ${datamodule.max_num_nodes}

    # Cast all attributes to either float or long. Doing this only now
    # allows speeding up disk I/O and CPU->GPU transfer
    - transform: NAGCast

    # Apply geometric transforms affecting position, offsets, normals
    # before calling transforms relying on those, such as on-the-fly
    # edge features computation
    - transform: NAGJitterKey
      params:
        key: 'pos'
        sigma: ${datamodule.pos_jitter}
        trunc: ${datamodule.voxel}
    - transform: RandomTiltAndRotate
      params:
        phi: ${datamodule.tilt_n_rotate_phi}
        theta: ${datamodule.tilt_n_rotate_theta}
    - transform: RandomAnisotropicScale
      params:
        delta: ${datamodule.anisotropic_scaling}
    - transform: RandomAxisFlip
      params:
        p: 0.5

    # Compute some horizontal and vertical edges on-the-fly. Those are
    # only computed now since they can be deduced from point and node
    # attributes. Besides, the OnTheFlyHorizontalEdgeFeatures transform
    # takes a trimmed graph as input and doubles its size, creating j->i
    # for each input i->j edge
    - transform: OnTheFlyHorizontalEdgeFeatures
      params:
        keys: ${datamodule.edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}
    - transform: OnTheFlyVerticalEdgeFeatures
      params:
        keys: ${datamodule.v_edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}

    # Edge sampling is only performed after the horizontal graph is
    # untrimmed by OnTheFlyHorizontalEdgeFeatures
    - transform: NAGRestrictSize
      params:
        level: '1+'
        num_edges: ${datamodule.max_num_edges}

    # Move all point and segment features to 'x'
    - transform: NAGAddKeysTo
      params:
        level: 0
        keys: ${eval:'ListConfig(${datamodule.point_hf})'}
        to: 'x'
    - transform: NAGAddKeysTo
      params:
        level: '1+'
        keys: ${eval:'ListConfig(${datamodule.segment_hf})'}
        to: 'x'

    # Add some noise to, and randomly drop, some point, node and edge features
    - transform: NAGJitterKey
      params:
        key: 'x'
        sigma: ${datamodule.node_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.node_feat_jitter}'}
    - transform: NAGJitterKey
      params:
        key: 'edge_attr'
        sigma: ${datamodule.h_edge_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.h_edge_feat_jitter}'}
    - transform: NAGJitterKey
      params:
        key: 'v_edge_attr'
        sigma: ${datamodule.v_edge_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.v_edge_feat_jitter}'}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.node_feat_drop}
        key: 'x'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.h_edge_feat_drop}
        key: 'edge_attr'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.v_edge_feat_drop}
        key: 'v_edge_attr'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.node_row_drop}
        key: 'x'
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.h_edge_row_drop}
        key: 'edge_attr'
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.v_edge_row_drop}
        key: 'v_edge_attr'
        to_mean: ${datamodule.drop_to_mean}

    # Add self-loops in the horizontal graph
    - transform: NAGAddSelfLoops

    # Compute the instance graph for instantiation
    # NB: setting datamodule.instance: False will skip this step
    - transform: OnTheFlyInstanceGraph
      params:
        level: ${eval:'1 if ${datamodule.instance} else -1'}
        num_classes: ${datamodule.num_classes}
        k_max: ${datamodule.instance_k_max}
        radius: ${datamodule.instance_radius}

# GPU-based val transforms
on_device_val_transform:

    # Add a node_size attribute to all segments, this is needed for
    # segment-wise position normalization with UnitSphereNorm
    - transform: NodeSize

    # Cast all attributes to either float or long. Doing this only now
    # allows speeding up disk I/O and CPU->GPU transfer
    - transform: NAGCast

    # Compute some horizontal and vertical edges on-the-fly. Those are
    # only computed now since they can be deduced from point and node
    # attributes. Besides, the OnTheFlyHorizontalEdgeFeatures transform
    # takes a trimmed graph as input and doubles its size, creating j->i
    # for each input i->j edge
    - transform: OnTheFlyHorizontalEdgeFeatures
      params:
        keys: ${datamodule.edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}
    - transform: OnTheFlyVerticalEdgeFeatures
      params:
        keys: ${datamodule.v_edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}

    # Move all point and segment features to 'x'
    - transform: NAGAddKeysTo
      params:
        level: 0
        keys: ${eval:'ListConfig(${datamodule.point_hf})'}
        to: 'x'
    - transform: NAGAddKeysTo
      params:
        level: '1+'
        keys: ${eval:'ListConfig(${datamodule.segment_hf})'}
        to: 'x'

    # Add self-loops in the horizontal graph
    - transform: NAGAddSelfLoops

    # Compute the instance graph for instantiation
    # NB: setting datamodule.instance: False will skip this step
    - transform: OnTheFlyInstanceGraph
      params:
        level: ${eval:'1 if ${datamodule.instance} else -1'}
        num_classes: ${datamodule.num_classes}
        k_max: ${datamodule.instance_k_max}
        radius: ${datamodule.instance_radius}

# GPU-based test transforms
on_device_test_transform: ${datamodule.on_device_val_transform}

drprojects commented 1 month ago

Like I said, you should use tiling then. Your scenes may be small in metric size, but if you are using millimetric voxels for centimetric instances, I guess your cloud is extremely dense. Aim for tiles of roughly 1M to 10M points.
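
Turning a budget of roughly 1M to 10M points per tile into a tiling value is simple arithmetic. The sketch below assumes that xy_tiling=n splits the cloud into an n × n grid (n² tiles), that pc_tiling=k splits it into 2**k tiles, and that points are spread evenly across tiles. Real tiles are rarely that uniform, so the result is best treated as a lower bound. The function names are hypothetical helpers, not repository utilities.

```python
import math

def tiles_needed(n_points: int, max_points_per_tile: int) -> int:
    """Minimum number of equally filled tiles that stay under the budget."""
    # Integer ceiling division, at least one tile
    return max(1, -(-n_points // max_points_per_tile))

def min_xy_tiling(n_points: int, max_points_per_tile: int) -> int:
    """Smallest n such that an n x n grid yields at least tiles_needed tiles."""
    needed = tiles_needed(n_points, max_points_per_tile)
    n = math.isqrt(needed)
    return n if n * n >= needed else n + 1

def min_pc_tiling(n_points: int, max_points_per_tile: int) -> int:
    """Smallest k such that 2**k tiles is at least tiles_needed."""
    needed = tiles_needed(n_points, max_points_per_tile)
    # For needed >= 1, (needed - 1).bit_length() equals ceil(log2(needed))
    return (needed - 1).bit_length()

n_points = 50_000_000  # hypothetical cloud size
for budget in (1_000_000, 10_000_000):
    print(budget, min_xy_tiling(n_points, budget), min_pc_tiling(n_points, budget))
```

For a hypothetical 50M-point cloud and a 10M-point budget, this gives xy_tiling=3 (9 tiles) or pc_tiling=3 (8 tiles).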

ImaneTopo commented 1 month ago

Thank you so much for your response. On that point, which tiling would be better to apply, xy_tiling or pc_tiling? In dales.yaml only xy_tiling is mentioned, with a specific value. Could I replace it with pc_tiling?

Another thing: in my case I have 2 classes, but I am only interested in better segmentation of one of them (the class of the objects I care about). The problem is that the other class is far more dominant than my instance class. I thought of subsampling only the other class to improve the IoU of my instance class, but I don't know whether that is possible or how to do it.
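
Subsampling only the dominant class can be done on the raw cloud, before it enters the preprocessing pipeline. The NumPy sketch below keeps every point of the other classes and a random fraction of the dominant one. subsample_dominant_class is a hypothetical helper, not a Superpoint Transformer API, and pos and y are assumed to be plain arrays of coordinates and labels. Since density-dependent features (linearity, planarity, scattering) and the partition are computed after this step, thinning one class may also change how that class is described, so comparing against a run on unmodified data seems advisable.

```python
import numpy as np

def subsample_dominant_class(pos, y, dominant_label, keep_ratio, seed=0):
    """Keep all points whose label differs from dominant_label, plus a random
    subset (drawn without replacement) of keep_ratio of the dominant points."""
    rng = np.random.default_rng(seed)
    dominant_idx = np.flatnonzero(y == dominant_label)
    other_idx = np.flatnonzero(y != dominant_label)
    n_keep = int(round(keep_ratio * dominant_idx.size))
    kept_dominant = rng.choice(dominant_idx, size=n_keep, replace=False)
    # Sort so the surviving points keep their original relative order
    keep = np.sort(np.concatenate([other_idx, kept_dominant]))
    return pos[keep], y[keep]

# Hypothetical toy cloud: 900 background points (label 0), 100 object points (label 1)
rng = np.random.default_rng(42)
pos = rng.random((1000, 3)).astype(np.float32)
y = np.zeros(1000, dtype=np.int64)
y[:100] = 1
pos_sub, y_sub = subsample_dominant_class(pos, y, dominant_label=0, keep_ratio=0.2)
print(np.bincount(y_sub))  # -> [180 100]
```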

drprojects commented 1 month ago

Yes, you can tune xy_tiling and pc_tiling to your needs, but you can only use one at a time. Have a look at the Superpoint Transformer 🧑‍🏫 tutorial notebook: I created a section for playing with tiling and understanding what each type of tiling does. If none of these tiling strategies suits your needs, I am sure you can figure out another way to subdivide your data. Feel free to send us a PR if you create a cool new tiling utility :wink:
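
The difference between the two strategies can be illustrated with a small NumPy sketch. It assumes, based on the parameter names and the tutorial notebook, that xy_tiling=n cuts the cloud into an n × n grid in the XY plane and that pc_tiling=k recursively bisects it k times along its principal axis into 2**k tiles. xy_grid_tiles and pc_bisect_tiles are hypothetical re-implementations of these ideas, not the repository's code, and splitting at the median is a choice made here for illustration.

```python
from __future__ import annotations

import numpy as np

def xy_grid_tiles(pos: np.ndarray, n: int) -> list[np.ndarray]:
    """Split points into an n x n grid over their XY bounding box.
    Returns one index array per tile (possibly empty)."""
    xy = pos[:, :2]
    lo = xy.min(axis=0)
    span = np.maximum(xy.max(axis=0) - lo, 1e-12)  # guard against a zero-width axis
    # Integer cell coordinates in [0, n - 1]; points on the max edge go to the last cell
    cell = np.minimum(((xy - lo) / span * n).astype(np.int64), n - 1)
    tile_id = cell[:, 0] * n + cell[:, 1]
    return [np.flatnonzero(tile_id == t) for t in range(n * n)]

def pc_bisect_tiles(pos: np.ndarray, k: int) -> list[np.ndarray]:
    """Recursively split points k times at the median of their principal
    axis, giving 2**k tiles with (near) equal point counts."""
    tiles = [np.arange(pos.shape[0])]
    for _ in range(k):
        next_tiles = []
        for idx in tiles:
            p = pos[idx] - pos[idx].mean(axis=0)
            # eigh returns eigenvalues in ascending order: last column is the principal axis
            _, eigvecs = np.linalg.eigh(p.T @ p)
            order = np.argsort(p @ eigvecs[:, -1])
            half = idx.size // 2
            next_tiles += [idx[order[:half]], idx[order[half:]]]
        tiles = next_tiles
    return tiles

rng = np.random.default_rng(0)
pos = rng.random((1000, 3)) * [10.0, 4.0, 1.0]  # hypothetical elongated scene
print([t.size for t in xy_grid_tiles(pos, 3)])
print([t.size for t in pc_bisect_tiles(pos, 3)])
```

In this sketch the grid produces tiles of fixed extent whose point counts follow the local density, whereas the median bisection produces tiles of equal point count whose extent varies. Since memory is driven by points per tile, the balanced variant is the easier one to size against a point budget.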

There are 2 strategies for driving your model to focus on hard classes:

The rest is up to you!