Pointcept / PointTransformerV2

[NeurIPS'22] An official PyTorch implementation of PTv2.
347 stars 23 forks source link

Could you provide the config file for ModelNet40? #9

Closed zhangchbin closed 1 year ago

Gofinge commented 1 year ago

Sure, here is our old config for ModelNet40. It does not work with the current version of PCR, but you can edit it to support the released version. Also, you need to remove the PTv2 decoder to do object classification; you can refer to /pcr/model/point_transformer/point_transformer_cls.py:

# ModelNet40 object-classification config for PTv2 (legacy PCR-era format).
# NOTE(review): per the accompanying comment, this does not load in the
# current release as-is and needs porting.

# Base config files this config builds on; presumably merged by the config
# loader, with the values below taking precedence — confirm against the loader.
_base_ = [
    '../_base_/datasets/modelnet40.py',
    '../_base_/schedulers/multi-step_sgd.py',
    '../_base_/tests/classification.py',
    '../_base_/default_runtime.py'
]

# GPU indices to train on (machine-specific; adjust for your setup).
train_gpu = [2, 3]

batch_size = 16
batch_size_val = 16
# Name of the evaluation metric to track; presumably overall accuracy — confirm.
metric = "allAcc"

# PTv2 model, selected by its registry name in `type`; presumably the
# classification variant — confirm against the model registry.
model = dict(
    type="pt2v2m2c",
    # Presumably xyz + normal per point (the data root is the "normal"
    # resampled ModelNet40) — confirm against the dataset's feature layout.
    in_channels=6,
    num_classes=40,
    # 5 widths for 4 encoder stages: presumably the first is the embedding
    # width; attn_groups then divides channels[1:] evenly (8 channels/group).
    # TODO confirm against the model definition.
    channels=(48, 96, 192, 384, 512),
    embed_depth=1,
    embed_num_samples=8,
    embed_group=4,
    enc_depths=(2, 2, 6, 2),
    # NOTE(review): object classification should not use the PTv2 decoder;
    # confirm whether this model type actually reads dec_depths.
    dec_depths=(1, 1, 1, 1),
    down_stride=(4, 4, 4, 4),
    down_num_samples=(0.05, 0.1, 0.2, 0.4),  # grid sizes per stage, despite the key name
    attn_groups=(12, 24, 48, 64),
    attn_num_samples=(16, 16, 16, 16),
    attn_qkv_bias=True,
    # Key spelling ("expend") must match the model's parameter name; do not
    # correct it here without also changing the model.
    mlp_channels_expend_ratio=1.,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.3,
)

epochs = 100
start_epoch = 0
optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0001, nesterov=True)
# Milestones evaluate to [60.0, 80.0] (floats), i.e. 60% and 80% of `epochs`.
# Presumably the LR is multiplied by gamma at each milestone. NOTE(review):
# confirm the unit (epochs vs. iterations) the scheduler expects given
# steps_per_epoch=1, and whether it accepts float milestones.
scheduler = dict(type='MultiStepLR', milestones=[epochs * 0.6, epochs * 0.8], steps_per_epoch=1, gamma=0.1)

# dataset settings
dataset_type = "ModelNetDataset"
data_root = "data/modelnet40_normal_resampled"
# NOTE(review): not referenced by the dataset dicts below; presumably read by
# a base config or the dataset class — confirm.
cache_data = False
# The 40 ModelNet40 category names (list order presumably defines the label
# index — confirm against the dataset class).
names = ["airplane", "bathtub", "bed", "bench", "bookshelf",
         "bottle", "bowl", "car", "chair", "cone",
         "cup", "curtain", "desk", "door", "dresser",
         "flower_pot", "glass_box", "guitar", "keyboard", "lamp",
         "laptop", "mantel", "monitor", "night_stand", "person",
         "piano", "plant", "radio", "range_hood", "sink",
         "sofa", "stairs", "stool", "table", "tent",
         "toilet", "tv_stand", "vase", "wardrobe", "xbox"]

data = dict(
    num_classes=40,
    ignore_label=-1,  # dummy ignore: presumably never equals a real label; also feeds the loss below
    names=names,
    train=dict(
        type=dataset_type,
        split="train",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            # The commented-out transforms below are disabled; they are kept
            # for reference. Only scale/shift augmentation and shuffling are active.
            # dict(type="CenterShift", apply_z=True),
            # dict(type="RandomRotate", angle=[-1, 1], axis='z', center=[0, 0, 0], p=0.5),
            # dict(type="RandomRotate", angle=[-1/24, 1/24], axis='x', p=0.5),
            # dict(type="RandomRotate", angle=[-1/24, 1/24], axis='y', p=0.5),
            dict(type="RandomScale", scale=[0.9, 1.1]),
            dict(type="RandomShift", shift=[0.2, 0.2, 0.2]),
            # dict(type="RandomFlip", p=0.5),
            # dict(type="RandomJitter", sigma=0.005, clip=0.02),
            # dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]),

            # dict(type="Voxelize", voxel_size=0.01, hash_type='fnv', mode='train'),
            # dict(type="SphereCrop", point_max=10000, mode='random'),
            # dict(type="CenterShift", apply_z=True),
            dict(type="ShufflePoint"),
            dict(type="ToTensor")
        ],
        loop=2,  # presumably repeats the train split twice per epoch — confirm
        test_mode=False,
    ),

    # NOTE: validation runs on the "test" split; no separate val split is configured.
    val=dict(
        type=dataset_type,
        split="test",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            dict(type="ToTensor")
        ],
        loop=1,
        test_mode=False,
    ),

    test=dict(
        type=dataset_type,
        split="test",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            dict(type="ToTensor")
        ],
        loop=1,
        test_mode=True,
        # Empty: no extra test-time options are set here.
        test_cfg=dict(
        )
    ),
)

# The loss's ignore index is read from data["ignore_label"] so the two stay in sync.
criteria = [
    dict(type="CrossEntropyLoss",
         loss_weight=1.0,
         ignore_index=data["ignore_label"])
]
VLadImirluren commented 7 months ago

Sure, here is our old config for ModelNet40. It does not work with the current version of PCR, but you can edit it to support the released version. Also, you need to remove the PTv2 decoder to do object classification; you can refer to /pcr/model/point_transformer/point_transformer_cls.py:

# ModelNet40 object-classification config for PTv2 (legacy PCR-era format).
# NOTE(review): per the accompanying comment, this does not load in the
# current release as-is and needs porting.

# Base config files this config builds on; presumably merged by the config
# loader, with the values below taking precedence — confirm against the loader.
_base_ = [
    '../_base_/datasets/modelnet40.py',
    '../_base_/schedulers/multi-step_sgd.py',
    '../_base_/tests/classification.py',
    '../_base_/default_runtime.py'
]

# GPU indices to train on (machine-specific; adjust for your setup).
train_gpu = [2, 3]

batch_size = 16
batch_size_val = 16
# Name of the evaluation metric to track; presumably overall accuracy — confirm.
metric = "allAcc"

# PTv2 model, selected by its registry name in `type`; presumably the
# classification variant — confirm against the model registry.
model = dict(
    type="pt2v2m2c",
    # Presumably xyz + normal per point (the data root is the "normal"
    # resampled ModelNet40) — confirm against the dataset's feature layout.
    in_channels=6,
    num_classes=40,
    # 5 widths for 4 encoder stages: presumably the first is the embedding
    # width; attn_groups then divides channels[1:] evenly (8 channels/group).
    # TODO confirm against the model definition.
    channels=(48, 96, 192, 384, 512),
    embed_depth=1,
    embed_num_samples=8,
    embed_group=4,
    enc_depths=(2, 2, 6, 2),
    # NOTE(review): object classification should not use the PTv2 decoder;
    # confirm whether this model type actually reads dec_depths.
    dec_depths=(1, 1, 1, 1),
    down_stride=(4, 4, 4, 4),
    down_num_samples=(0.05, 0.1, 0.2, 0.4),  # grid sizes per stage, despite the key name
    attn_groups=(12, 24, 48, 64),
    attn_num_samples=(16, 16, 16, 16),
    attn_qkv_bias=True,
    # Key spelling ("expend") must match the model's parameter name; do not
    # correct it here without also changing the model.
    mlp_channels_expend_ratio=1.,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.3,
)

epochs = 100
start_epoch = 0
optimizer = dict(type='SGD', lr=0.05, momentum=0.9, weight_decay=0.0001, nesterov=True)
# Milestones evaluate to [60.0, 80.0] (floats), i.e. 60% and 80% of `epochs`.
# Presumably the LR is multiplied by gamma at each milestone. NOTE(review):
# confirm the unit (epochs vs. iterations) the scheduler expects given
# steps_per_epoch=1, and whether it accepts float milestones.
scheduler = dict(type='MultiStepLR', milestones=[epochs * 0.6, epochs * 0.8], steps_per_epoch=1, gamma=0.1)

# dataset settings
dataset_type = "ModelNetDataset"
data_root = "data/modelnet40_normal_resampled"
# NOTE(review): not referenced by the dataset dicts below; presumably read by
# a base config or the dataset class — confirm.
cache_data = False
# The 40 ModelNet40 category names (list order presumably defines the label
# index — confirm against the dataset class).
names = ["airplane", "bathtub", "bed", "bench", "bookshelf",
         "bottle", "bowl", "car", "chair", "cone",
         "cup", "curtain", "desk", "door", "dresser",
         "flower_pot", "glass_box", "guitar", "keyboard", "lamp",
         "laptop", "mantel", "monitor", "night_stand", "person",
         "piano", "plant", "radio", "range_hood", "sink",
         "sofa", "stairs", "stool", "table", "tent",
         "toilet", "tv_stand", "vase", "wardrobe", "xbox"]

data = dict(
    num_classes=40,
    ignore_label=-1,  # dummy ignore: presumably never equals a real label; also feeds the loss below
    names=names,
    train=dict(
        type=dataset_type,
        split="train",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            # The commented-out transforms below are disabled; they are kept
            # for reference. Only scale/shift augmentation and shuffling are active.
            # dict(type="CenterShift", apply_z=True),
            # dict(type="RandomRotate", angle=[-1, 1], axis='z', center=[0, 0, 0], p=0.5),
            # dict(type="RandomRotate", angle=[-1/24, 1/24], axis='x', p=0.5),
            # dict(type="RandomRotate", angle=[-1/24, 1/24], axis='y', p=0.5),
            dict(type="RandomScale", scale=[0.9, 1.1]),
            dict(type="RandomShift", shift=[0.2, 0.2, 0.2]),
            # dict(type="RandomFlip", p=0.5),
            # dict(type="RandomJitter", sigma=0.005, clip=0.02),
            # dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]),

            # dict(type="Voxelize", voxel_size=0.01, hash_type='fnv', mode='train'),
            # dict(type="SphereCrop", point_max=10000, mode='random'),
            # dict(type="CenterShift", apply_z=True),
            dict(type="ShufflePoint"),
            dict(type="ToTensor")
        ],
        loop=2,  # presumably repeats the train split twice per epoch — confirm
        test_mode=False,
    ),

    # NOTE: validation runs on the "test" split; no separate val split is configured.
    val=dict(
        type=dataset_type,
        split="test",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            dict(type="ToTensor")
        ],
        loop=1,
        test_mode=False,
    ),

    test=dict(
        type=dataset_type,
        split="test",
        data_root=data_root,
        class_names=names,
        transform=[
            dict(type="NormalizeCoord"),
            dict(type="ToTensor")
        ],
        loop=1,
        test_mode=True,
        # Empty: no extra test-time options are set here.
        test_cfg=dict(
        )
    ),
)

# The loss's ignore index is read from data["ignore_label"] so the two stay in sync.
criteria = [
    dict(type="CrossEntropyLoss",
         loss_weight=1.0,
         ignore_index=data["ignore_label"])
]

Since I have only just started working with this project, there are still many things I am not familiar with. Could you please update this config file to the latest version of Pointcept? Thanks!

Gofinge commented 7 months ago

> ...not familiar with. Could you please update this config file to the latest version of Pointcept?

Thanks for your contribution. I am also working on supporting object-level tasks again; it will be added within the next two versions.