Cavendish518 / SFPNet

This repo is the official code for the paper SFPNet (ECCV 2024) and the dataset S.MID.

About the SparseFocalModulation module #3

Closed: hymhymhym123 closed this issue 2 days ago

hymhymhym123 commented 1 week ago

Hello, I am trying to port your SparseFocalModulation code into SphereFormer. SphereFormer's inputs are feats: [N, c], xyz: [N, 3], and batch: [N], but I am not sure what information the input x of x = self.modulation(x) contains or roughly what shape it has. How should I convert my data into that form?

Cavendish518 commented 1 week ago

Hi, x is consistent with the comment in forward: it is a sparse tensor from the spconv library. You can check the relevant spconv code to see how a sparse tensor is constructed. For the correspondence with SphereFormer, you can refer to the backbone file. Good luck!
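
For reference, a minimal sketch of building such a sparse tensor from feats [N, c], xyz [N, 3], batch [N]. The helper, the voxel size, and the point-cloud origin are illustrative assumptions, not code from this repo:

```python
import torch
import spconv.pytorch as spconv

def to_sparse_tensor(feats, xyz, batch,
                     voxel_size=0.05, pc_min=(-51.2, -51.2, -4.0)):
    """Pack feats [N, c], xyz [N, 3], batch [N] into a spconv SparseConvTensor."""
    # Quantize continuous coordinates into integer voxel indices.
    pc_min = torch.tensor(pc_min, device=xyz.device)
    coords = ((xyz - pc_min) / voxel_size).int()          # [N, 3], int32
    # spconv expects indices of shape [N, 4]: (batch_idx, ix, iy, iz),
    # in the same axis order as spatial_shape.
    indices = torch.cat([batch.view(-1, 1).int(), coords], dim=1)
    spatial_shape = (coords.max(dim=0).values + 1).tolist()
    batch_size = int(batch.max().item()) + 1
    # Note: a real pipeline should first deduplicate points that fall into
    # the same voxel (e.g. average their features).
    return spconv.SparseConvTensor(feats, indices, spatial_shape, batch_size)
```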

hymhymhym123 commented 1 week ago

Hello, I looked through the backbone file, but (perhaps due to my own limitations) I'd like to ask: what do the parameters focal_r, focal_th, and focal_h that the UNet requires each represent?

Cavendish518 commented 1 week ago

> Hello, I looked through the backbone file, but (perhaps due to my own limitations) I'd like to ask: what do the parameters focal_r, focal_th, and focal_h that the UNet requires each represent?

These are the hyperparameters controlling the sparse convolution kernels along the three spatial dimensions in the three focal levels. For the mainstream LiDAR sensors currently on the market, these parameters do not need to be tuned. Best regards.
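
For concreteness, in the backbone pasted below these settings are per-stage lists: each UBlock consumes the head of each list for its own ResFBlock and recurses with the tail. The values here are copied from the questioner's config further down and are illustrative only:

```python
# One kernel-size list per UNet stage (5 stages for layers=[32, 64, 128, 256, 256]).
focal_r  = [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]
focal_th = [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]
focal_h  = [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]

# Each UBlock consumes the head of each list for its own ResFBlock ...
#   ResFBlock(..., focal_x=focal_r[0], focal_y=focal_th[0], focal_z=focal_h[0], ...)
# ... and recurses with the tails:
#   UBlock(nPlanes[1:], focal_r[1:], focal_th[1:], focal_h[1:], ...)
```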

hymhymhym123 commented 1 week ago

Hello, my code runs now, but during training I see messages like the ones below, and I am not sure whether this is a gradient explosion or something else. Have you run into anything similar, and could you share how you resolved it?

```
[11/21 20:19:10 main-logger]: Epoch: [1/50][1020/4782] Data 0.001 (0.006) Batch 0.883 (0.915) Remain 60:31:03 Loss nan Lr: [0.00597696, 0.0005977] Accuracy 0.0619
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
```

This is the backbone file I adapted from SphereFormer's unet_spherical_transformer file and your backbone file (I may well have made some very basic mistake...):

```python
import functools
import warnings
import torch
import torch.nn as nn
import numpy as np
import spconv.pytorch as spconv
from spconv.pytorch.modules import SparseModule
from spconv.core import ConvAlgo
from collections import OrderedDict
from torch_scatter import scatter_mean
from model.spherical_transformer import SphereFormer

# add
from model.SFPNet import ResFBlock

class ResidualBlock(SparseModule):
    def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
        super().__init__()
        if in_channels == out_channels:
            self.i_branch = spconv.SparseSequential(
                nn.Identity()
            )
        else:
            self.i_branch = spconv.SparseSequential(
                spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)
            )
        self.conv_branch = spconv.SparseSequential(
            norm_fn(in_channels),
            nn.ReLU(),
            spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key),
            norm_fn(out_channels),
            nn.ReLU(),
            spconv.SubMConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
        )

    def forward(self, input):
        identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size)
        output = self.conv_branch(input)
        output = output.replace_feature(output.features + self.i_branch(identity).features)
        return output

class VGGBlock(SparseModule):
    def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
        super().__init__()
        self.conv_layers = spconv.SparseSequential(
            norm_fn(in_channels),
            nn.ReLU(),
            spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
        )

    def forward(self, input):
        return self.conv_layers(input)

def get_downsample_info(xyz, batch, indice_pairs):
    pair_in, pair_out = indice_pairs[0], indice_pairs[1]
    valid_mask = (pair_in != -1)
    valid_pair_in, valid_pair_out = pair_in[valid_mask].long(), pair_out[valid_mask].long()
    xyz_next = scatter_mean(xyz[valid_pair_in], index=valid_pair_out, dim=0)
    batch_next = scatter_mean(batch.float()[valid_pair_in], index=valid_pair_out, dim=0)
    return xyz_next, batch_next

class Cosine_aug(nn.Module):
    """ Frequency augmentation (optional) """

    def __init__(self, in_dim=3, out_dim=6, alpha=10000, beta=1):
        super().__init__()
        self.in_dim = in_dim    # input dimension, usually the coordinate dimension (e.g. 3 for 3D coordinates)
        self.out_dim = out_dim  # output dimension of the generated features
        self.alpha, self.beta = alpha, beta  # scaling parameters for the embedding

    def forward(self, xyz):
        # device of the input tensor (GPU or CPU)
        cur_dev = xyz.get_device()

        # N is the number of points, _ is the input dimension (should equal in_dim)
        N, _ = xyz.shape

        # number of frequency bands per coordinate axis
        feat_dim = self.out_dim // (self.in_dim * 2)

        # range tensor from 0 to feat_dim-1
        feat_range = torch.arange(feat_dim).float().to(cur_dev)

        # per-band scale, powers of beta
        dim_embed = torch.pow(self.beta, feat_range / feat_dim)

        # expand xyz by one dimension and scale it
        div_embed = torch.div(self.alpha * xyz.unsqueeze(-1), dim_embed)

        # sine embedding
        sin_embed = torch.sin(div_embed)

        # cosine embedding
        cos_embed = torch.cos(div_embed)

        # stack the sine and cosine embeddings into one tensor
        position_embed = torch.stack([sin_embed, cos_embed], dim=3).flatten(2)

        # reshape to (N, out_dim) for output
        position_embed = position_embed.reshape(N, self.out_dim)

        # return the positional embedding features
        return position_embed

class UBlock(nn.Module):
    def __init__(self, nPlanes, focal_r, focal_th, focal_h, norm_fn, block_reps, block,
                 window_size, window_size_sphere, quant_size, quant_size_sphere,
                 head_dim=16, window_size_scale=[2.0, 2.0],
                 rel_query=True, rel_key=True, rel_value=True,
                 drop_path=0.0, indice_key_id=1, grad_checkpoint_layers=[],
                 sphere_layers=[1, 2, 3, 4, 5], a=0.05*0.25, ctx_mode=0):

        super().__init__()
        # add
        focal_level = 3

        self.ctx_mode = ctx_mode  # context mode
        self.nPlanes = nPlanes
        self.indice_key_id = indice_key_id
        self.grad_checkpoint_layers = grad_checkpoint_layers
        self.sphere_layers = sphere_layers

        blocks = {'block{}'.format(i): block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id)) for i in range(block_reps)}
        blocks = OrderedDict(blocks)
        self.blocks = spconv.SparseSequential(blocks)

        if indice_key_id in sphere_layers:

            # add
            self.focalblk = ResFBlock(nPlanes[0], drop=0., drop_path=drop_path[0], focal_level=focal_level,
                                      focal_x=focal_r[0],
                                      focal_y=focal_th[0], focal_z=focal_h[0],
                                      indice_key='focal{}'.format(indice_key_id))

            self.window_size = window_size
            self.window_size_sphere = window_size_sphere
            num_heads = nPlanes[0] // head_dim
            self.transformer_block = SphereFormer(
                nPlanes[0],
                num_heads,
                window_size,
                window_size_sphere,
                quant_size,
                quant_size_sphere,
                indice_key='sphereformer{}'.format(indice_key_id),
                rel_query=rel_query,
                rel_key=rel_key,
                rel_value=rel_value,
                drop_path=drop_path[0],
                a=a,
            )

        if len(nPlanes) > 1:
            self.conv = spconv.SparseSequential(
                norm_fn(nPlanes[0]),
                nn.ReLU(),
                spconv.SparseConv3d(nPlanes[0], nPlanes[1], kernel_size=2, stride=2, bias=False, indice_key='spconv{}'.format(indice_key_id), algo=ConvAlgo.Native)
            )

            window_size_scale_cubic, window_size_scale_sphere = window_size_scale
            window_size_next = np.array([
                window_size[0]*window_size_scale_cubic,
                window_size[1]*window_size_scale_cubic,
                window_size[2]*window_size_scale_cubic
            ])
            quant_size_next = np.array([
                quant_size[0]*window_size_scale_cubic,
                quant_size[1]*window_size_scale_cubic,
                quant_size[2]*window_size_scale_cubic
            ])
            window_size_sphere_next = np.array([
                window_size_sphere[0]*window_size_scale_sphere,
                window_size_sphere[1]*window_size_scale_sphere,
                window_size_sphere[2]
            ])
            quant_size_sphere_next = np.array([
                quant_size_sphere[0]*window_size_scale_sphere,
                quant_size_sphere[1]*window_size_scale_sphere,
                quant_size_sphere[2]
            ])

            self.u = UBlock(nPlanes[1:],
                focal_r[1:],
                focal_th[1:],
                focal_h[1:],
                norm_fn,
                block_reps,
                block,
                window_size_next,
                window_size_sphere_next,
                quant_size_next,
                quant_size_sphere_next,
                window_size_scale=window_size_scale,
                rel_query=rel_query,
                rel_key=rel_key,
                rel_value=rel_value,
                drop_path=drop_path[1:],
                indice_key_id=indice_key_id+1,
                grad_checkpoint_layers=grad_checkpoint_layers,
                sphere_layers=sphere_layers,
                a=a,
                ctx_mode=0,
            )

            self.deconv = spconv.SparseSequential(
                norm_fn(nPlanes[1]),
                nn.ReLU(),
                spconv.SparseInverseConv3d(nPlanes[1], nPlanes[0], kernel_size=2, bias=False, indice_key='spconv{}'.format(indice_key_id), algo=ConvAlgo.Native)
            )

            blocks_tail = {}
            for i in range(block_reps):
                blocks_tail['block{}'.format(i)] = block(nPlanes[0] * (2 - i), nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
            blocks_tail = OrderedDict(blocks_tail)
            self.blocks_tail = spconv.SparseSequential(blocks_tail)

    def forward(self, inp, xyz, batch):

        assert (inp.indices[:, 0] == batch).all()

        output = self.blocks(inp)

        # transformer
        if self.indice_key_id in self.sphere_layers:
            '''
            if self.indice_key_id in self.grad_checkpoint_layers:
                def run(feats_, xyz_, batch_):
                    return self.transformer_block(feats_, xyz_, batch_)
                transformer_features = torch.utils.checkpoint.checkpoint(run, output.features, xyz, batch)
            else:
                transformer_features = self.transformer_block(output.features, xyz, batch)
            output = output.replace_feature(transformer_features)
            '''
            # add
            output = self.focalblk(output)

        identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape, output.batch_size)

        if len(self.nPlanes) > 1:
            output_decoder = self.conv(output)

            # downsample
            indice_pairs = output_decoder.indice_dict['spconv{}'.format(self.indice_key_id)].indice_pairs
            xyz_next, batch_next = get_downsample_info(xyz, batch, indice_pairs)

            output_decoder = self.u(output_decoder, xyz_next, batch_next.long())
            output_decoder = self.deconv(output_decoder)
            output = output.replace_feature(torch.cat((identity.features, output_decoder.features), dim=1))
            output = self.blocks_tail(output)

        return output

class Semantic(nn.Module):
    def __init__(self, focal_th, focal_r, focal_h, input_c, m, classes, block_reps, block_residual,
                 layers, window_size, window_size_sphere, quant_size, quant_size_sphere,
                 rel_query=True, rel_key=True, rel_value=True,
                 drop_path_rate=0.0, window_size_scale=2.0,
                 grad_checkpoint_layers=[], sphere_layers=[1, 2, 3, 4, 5], a=0.05*0.25):
        super().__init__()

        norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)

        if block_residual:
            block = ResidualBlock
        else:
            block = VGGBlock

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 7)]

        #### backbone
        self.pose_ini = Cosine_aug(3, 6, 10000, 1)
        self.input_conv = spconv.SparseSequential(
            spconv.SubMConv3d(input_c+6, m, kernel_size=3, padding=1, bias=False, indice_key='subm1')
        )

        self.unet = UBlock(layers,
            focal_r,
            focal_th,
            focal_h,
            norm_fn,
            block_reps,
            block,
            window_size,
            window_size_sphere,
            quant_size,
            quant_size_sphere,
            window_size_scale=window_size_scale,
            rel_query=rel_query,
            rel_key=rel_key,
            rel_value=rel_value,
            drop_path=dpr,
            indice_key_id=1,
            grad_checkpoint_layers=grad_checkpoint_layers,
            sphere_layers=sphere_layers,
            a=a,
        )

        self.output_layer = spconv.SparseSequential(
            norm_fn(m),
            nn.ReLU()
        )
        self.apply(self.set_bn_init)
        #### semantic segmentation
        self.linear = nn.Linear(m, classes)  # bias(default): True

    @staticmethod
    def set_bn_init(m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(0.0)

    def forward(self, input, xyz, batch):
        '''
        :param input_map: (N), int, cuda
        '''
        ret_new_cos = []
        for i in range(len(torch.unique(batch))):
            xyz_c_b = xyz[batch == i]
            ret_new_cos.append(self.pose_ini(xyz_c_b))

        ret_cos = torch.cat(ret_new_cos, dim=0)
        input = input.replace_feature(torch.cat([input.features, ret_cos], dim=1))

        output = self.input_conv(input)
        output = self.unet(output, xyz, batch)
        output = self.output_layer(output)

        #### semantic segmentation
        semantic_scores = self.linear(output.features)   # (N, nClass), float
        return semantic_scores
```

My config file is:

```yaml
DATA:
  data_name: semantic_kitti
  data_root: /home/hym/下载/SemanticKITTI/dataset
  label_mapping: util/semantic-kitti.yaml
  classes: 19
  fea_dim: 6
  voxel_size: [0.05, 0.05, 0.05]
  voxel_max: 120000

TRAIN:
  # arch
  arch: unet_spherical_transformer
  input_c: 4
  m: 32
  block_reps: 2
  block_residual: True
  layers: [32, 64, 128, 256, 256]
  focal_r: [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]
  focal_th: [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]
  focal_h: [[3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1], [3, 1, 1, 1]]
  quant_size_scale: 24
  patch_size: 1
  window_size: 6
  use_xyz: True
  sync_bn: True  # adopt sync_bn or not
  rel_query: True
  rel_key: True
  rel_value: True
  drop_path_rate: 0.3
  max_batch_points: 1000000
  class_weight: [3.1557, 8.7029, 7.8281, 6.1354, 6.3161, 7.9937, 8.9704, 10.1922, 1.6155, 4.2187, 1.9385, 5.5455, 2.0198, 2.6261, 1.3212, 5.1102, 2.5492, 5.8585, 7.3929]
  xyz_norm: False
  pc_range: [[-51.2, -51.2, -4], [51.2, 51.2, 2.4]]
  window_size_sphere: [2, 2, 80]
  window_size_scale: [2.0, 1.5]
  sphere_layers: [1, 2, 3, 4, 5]
  grad_checkpoint_layers: []
  a: 0.0125
  loss_name: ce_loss
  use_tta: False
  vote_num: 4

  # training
  aug: True
  transformer_lr_scale: 0.1
  scheduler_update: step
  scheduler: Poly
  power: 0.9
  use_amp: True
  train_gpu: [0]
  workers: 16  # data loader workers
  batch_size: 8  # batch size for training
  batch_size_val: 8  # batch size for validation during training, memory and speed tradeoff
  base_lr: 0.006
  epochs: 50
  start_epoch: 0
  momentum: 0.9
  weight_decay: 0.02
  drop_rate: 0.5
  ignore_label: 255
  manual_seed: 123
  print_freq: 10
  save_freq: 1
  save_path: runs/semantic_kitti_unet32_spherical_transformer
  weight:  # path to initial weight (default: none)
  resume:  # path to latest checkpoint (default: none)
  evaluate: True  # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommended
  eval_freq: 1
  val: False

Distributed:
  dist_url: tcp://127.0.0.1:6789
  dist_backend: 'nccl'
  multiprocessing_distributed: True
  world_size: 1
  rank: 0
```

My training script is essentially identical to SphereFormer's; the only change is that when constructing the model I additionally pass the three arguments focal_th=args.focal_th, focal_r=args.focal_r, focal_h=args.focal_h. Thank you very much!

Cavendish518 commented 1 week ago


demo.py provides the complete inference code and backbone.py provides the complete network architecture for your reference; for the hyperparameters you can refer to issue #1. I suggest patiently reading the paper and the supplementary material, and carefully going through the complete code, the comments, and the issues. I also strongly recommend spending some time understanding the code of the related works. I hope this helps, thank you!
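
For the NaN reported above, a generic first line of debugging with AMP training (not specific to SFPNet) is to enable anomaly detection to locate the first op that produces NaN/Inf, and to clip gradients before the optimizer step. A minimal sketch, with all names and thresholds illustrative:

```python
import torch

# Debug only: reports the operation that first produces NaN/Inf (slow).
torch.autograd.set_detect_anomaly(True)

def train_step(model, criterion, optimizer, scaler, feats, xyz, batch_idx, target):
    """One AMP training step with NaN-oriented safeguards (illustrative)."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # the config above trains with use_amp: True
        logits = model(feats, xyz, batch_idx)
        loss = criterion(logits, target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)               # unscale so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    scaler.step(optimizer)                   # skips the update if gradients contain NaN/Inf
    scaler.update()
    return loss.detach()
```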