Pointcept / PointTransformerV3

[CVPR'24 Oral] Official repository of Point Transformer V3 (PTv3)
MIT License

A question about using "MSE loss" #59

Open QWTforGithub opened 5 months ago

QWTforGithub commented 5 months ago

Thank you very much for sharing the well-organized code. I am trying to build on your code. It works fine with the cross-entropy loss you defined, but I get the error "RuntimeError: CUDA error: device-side assert triggered" when I use my own MSE loss:

def ignore_label(scores, labels, ignore=None, valid=None, scene_id=None):
    if ignore is None:
        return scores, labels
    if valid is None:
        valid = labels != ignore
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels

@LOSSES.register_module('MSELoss')
class MSELoss(nn.Module):
    def __init__(
        self,
        pred="c_pred",
        target="c_target",
        valid_index=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        loss_weight=1.0,
        ignore_index=20
    ):
        super(MSELoss, self).__init__()
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self.pred = pred
        self.target = target
        self.valid_index = valid_index
        self.loss = nn.MSELoss(
            size_average=size_average,
            reduce=reduce,
            reduction=reduction
        )

    def forward(self, points):
        pred = points[self.pred]
        target = points[self.target]
        if self.valid_index:
            valid = points[self.valid_index]
            pred, target = ignore_label(pred, target, self.ignore_index, valid)

        loss = self.loss(pred, target) * self.loss_weight
        return loss

I tried reducing the data split size from 102400 to 51200, and even to 25600, and still got this error. The specific error is as follows:

[06/05 08:09:28 pointcept]: Train: [1/100][599/4800] Data 0.003 (0.004) Batch 0.317 (0.336) Remain 44:48:37 loss: 0.0409 Lr: 0.00061
[06/05 08:09:28 pointcept]: Train: [1/100][600/4800] Data 0.003 (0.004) Batch 0.368 (0.337) Remain 44:49:02 loss: 0.0404 Lr: 0.00061
[06/05 08:09:29 pointcept]: Train: [1/100][601/4800] Data 0.005 (0.004) Batch 0.409 (0.337) Remain 44:49:59 loss: 0.0343 Lr: 0.00061
[06/05 08:09:29 pointcept]: Train: [1/100][602/4800] Data 0.006 (0.004) Batch 0.406 (0.337) Remain 44:50:55 loss: 0.0352 Lr: 0.00061
[06/05 08:09:29 pointcept]: Train: [1/100][603/4800] Data 0.004 (0.004) Batch 0.377 (0.337) Remain 44:51:26 loss: 0.0372 Lr: 0.00061
[06/05 08:09:30 pointcept]: Train: [1/100][604/4800] Data 0.005 (0.004) Batch 0.379 (0.337) Remain 44:52:00 loss: 0.0356 Lr: 0.00061
[06/05 08:09:30 pointcept]: Train: [1/100][605/4800] Data 0.004 (0.004) Batch 0.328 (0.337) Remain 44:51:52 loss: 0.0568 Lr: 0.00061
[06/05 08:09:30 pointcept]: Train: [1/100][606/4800] Data 0.003 (0.004) Batch 0.285 (0.337) Remain 44:51:11 loss: 0.0379 Lr: 0.00061
[06/05 08:09:31 pointcept]: Train: [1/100][607/4800] Data 0.003 (0.004) Batch 0.341 (0.337) Remain 44:51:13 loss: 0.0338 Lr: 0.00061
[06/05 08:09:31 pointcept]: Train: [1/100][608/4800] Data 0.006 (0.004) Batch 0.314 (0.337) Remain 44:50:55 loss: 0.0781 Lr: 0.00061
[06/05 08:09:31 pointcept]: Train: [1/100][609/4800] Data 0.003 (0.004) Batch 0.341 (0.337) Remain 44:50:59 loss: 0.0640 Lr: 0.00061
[06/05 08:09:32 pointcept]: Train: [1/100][610/4800] Data 0.003 (0.004) Batch 0.303 (0.337) Remain 44:50:31 loss: 0.0464 Lr: 0.00061
[06/05 08:09:32 pointcept]: Train: [1/100][611/4800] Data 0.004 (0.004) Batch 0.319 (0.337) Remain 44:50:17 loss: 0.0370 Lr: 0.00061
[06/05 08:09:32 pointcept]: Train: [1/100][612/4800] Data 0.004 (0.004) Batch 0.304 (0.337) Remain 44:49:51 loss: 0.0317 Lr: 0.00061
[06/05 08:09:33 pointcept]: Train: [1/100][613/4800] Data 0.004 (0.004) Batch 0.314 (0.337) Remain 44:49:32 loss: 0.2427 Lr: 0.00061
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [80,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [81,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [82,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [83,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [84,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [85,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [86,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [87,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [88,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [89,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [90,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [91,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [92,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [93,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [94,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1695392020195/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [97,0,0], thread: [95,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
Traceback (most recent call last):
  File "/root/models/Pointcept_DM5/tools/train.py", line 68, in <module>
    main()
  File "/root/models/Pointcept_DM5/tools/train.py", line 55, in main
    launch(
  File "/root/models/Pointcept_DM5/pointcept/engines/launch.py", line 89, in launch
    main_func(cfg) # single-GPU training
  File "/root/models/Pointcept_DM5/tools/train.py", line 27, in main_worker
    trainer.train()
  File "/root/models/Pointcept_DM5/pointcept/engines/train.py", line 165, in train
    self.run_step()
  File "/root/models/Pointcept_DM5/pointcept/engines/train.py", line 236, in run_step
    output_dict = self.model(input_dict,input_dict)
  File "/root/packages/anconda3/envs/pt3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/packages/anconda3/envs/pt3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/models/Pointcept_DM5/pointcept/models/default.py", line 572, in forward
    point['c_pred'] = point['c_pred'][valid]
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

This problem is very troubling to me, can you give me some suggestions? Thank you.
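
The assertion in the log above (-sizes[i] <= index && index < sizes[i]) is the bounds check applied when a tensor is indexed, and the message itself suggests CUDA_LAUNCH_BLOCKING=1. A minimal debugging sketch in plain Python follows, under two assumptions: that `valid` holds integer indices into `c_pred`, and that the values can be copied to the host for inspection. The helper name `find_out_of_bounds` is hypothetical and not part of Pointcept.

```python
import os

# Assumption: this runs before the training script imports torch, so that the
# failing kernel is reported at the call that launched it.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")


def find_out_of_bounds(indices, size):
    """Return (position, index) pairs that violate -size <= index < size."""
    return [
        (pos, idx)
        for pos, idx in enumerate(indices)
        if not (-size <= idx < size)
    ]


# Toy data: 5 predicted points, with two indices that fall outside the range.
num_points = 5
valid = [0, 3, 4, 5, -6, -5]
print(find_out_of_bounds(valid, num_points))  # [(3, 5), (4, -6)]
```

With real tensors, a comparable check would compare the minimum and maximum of `valid` against the first dimension of the prediction just before the indexing line in default.py; running the same batch on CPU usually yields a readable IndexError instead of a device-side assert.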

Gofinge commented 5 months ago

Oh, I have also encountered a similar issue in some settings. I suspect it is caused by pooling. I rewrote the pooling in my development version, but have not yet verified it more broadly. I can attach the experimental version to this issue tomorrow.

Gofinge commented 5 months ago

May I confirm which model you are using?

QWTforGithub commented 5 months ago

May I confirm which model you are using?

Thank you for your reply. I am using Pointcept version 1.5.2 (downloaded directly from GitHub). Originally, I took the (N, 3) output of Point Transformer V3 (with the 20 output channels changed to 3) and aligned it to the original (N, 3) coordinates with MSE, and no such problem occurred. After I modified the code, this error appeared. I don't think I changed much; I only replaced the input features. Would you be willing to look at the code? If possible, I can send my current code to your email. Thank you.

QWTforGithub commented 5 months ago

If possible, please leave your email address where you can be contacted. Thank you.

Lizhinwafu commented 5 months ago

I also get this error. Maybe it is the loss? I have two datasets: one trains fine, the other does not work.

Lizhinwafu commented 5 months ago

image

QWTforGithub commented 5 months ago

May I confirm which model you are using?

Or you can try this: concatenate the coordinates (N, 3) with a random variable (N, 3), feed the result to the network as input, and align the output to the original coordinates (N, 3) with MSE. Similar errors occur after training for several epochs on ScanNet.

Gofinge commented 5 months ago

I am not sure either, but try replacing the pooling code with the following, which relies on grid_coord rather than on the serialized code:

class SerializedPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

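        self.proj = nn.Linear(in_channels, out_channels)
        # Default to None so the `is not None` checks in the forward methods
        # do not raise AttributeError when no norm/act layer is given.
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())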

    def serialized_forward(self, point: Point):
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"

        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point

    def grid_forward(self, point: Point):
        if "grid_coord" in point.keys():
            grid_coord = point.grid_coord
        elif {"coord", "grid_size"}.issubset(point.keys()):
            grid_coord = torch.div(
                point.coord - point.coord.min(0)[0],
                point.grid_size,
                rounding_mode="trunc",
            ).int()
        else:
            raise AssertionError(
                "[gird_coord] or [coord, grid_size] should be include in the Point"
            )
        grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
        grid_coord, cluster, counts = torch.unique(
            grid_coord,
            sorted=True,
            return_inverse=True,
            return_counts=True,
            dim=0,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=grid_coord,
            batch=point.batch[head_indices],
        )
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        order = point.order
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.serialization(order=order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        return point

    def forward(self, point: Point):
        # if self.stride == 2 ** (math.ceil(self.stride) - 1).bit_length():
        #     return self.serialized_forward(point)
        # else:
        #     return self.grid_forward(point)
        return self.grid_forward(point)

Please let me know whether this works.
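
For readers following the index bookkeeping in grid_forward (unique cells, sort by cluster, cumulative counts, head indices), the same steps can be sketched in plain Python. This is an illustration under assumptions, not the Pointcept implementation: `segment_pool` is a hypothetical name, features are scalars rather than vectors, coordinates are non-negative integers, and a single batch is assumed.

```python
from itertools import accumulate


def segment_pool(grid_coords, feats, stride, reduce=max):
    """Pool scalar features over grid cells of size `stride`."""
    # Coarsen each coordinate (floor division equals truncation for values >= 0).
    cells = [tuple(c // stride for c in gc) for gc in grid_coords]
    # Equivalent of torch.unique(..., return_inverse=True, return_counts=True).
    unique_cells = sorted(set(cells))
    cell_id = {cell: i for i, cell in enumerate(unique_cells)}
    cluster = [cell_id[c] for c in cells]
    counts = [0] * len(unique_cells)
    for c in cluster:
        counts[c] += 1
    # Point indices sorted by cluster (Python's sort is stable).
    indices = sorted(range(len(cluster)), key=cluster.__getitem__)
    # CSR-style pointer: cluster k owns indices[idx_ptr[k]:idx_ptr[k + 1]].
    idx_ptr = [0] + list(accumulate(counts))
    pooled = [
        reduce(feats[i] for i in indices[idx_ptr[k]:idx_ptr[k + 1]])
        for k in range(len(unique_cells))
    ]
    # First point of each cluster, used to carry over per-point attributes.
    head_indices = [indices[p] for p in idx_ptr[:-1]]
    return unique_cells, cluster, pooled, head_indices


grid = [(0, 0, 0), (1, 1, 0), (2, 0, 0), (3, 1, 1), (0, 1, 1)]
feat = [1.0, 5.0, 2.0, 4.0, 3.0]
cells, cluster, pooled, heads = segment_pool(grid, feat, stride=2)
print(cells)    # [(0, 0, 0), (1, 0, 0)]
print(cluster)  # [0, 0, 1, 1, 0]
print(pooled)   # [5.0, 4.0]
print(heads)    # [0, 2]
```

Here `cluster` plays the role of pooling_inverse: it maps every input point to its pooled cell, which is what an unpooling step needs later.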

QWTforGithub commented 5 months ago

I am not sure either, but try replacing the pooling code with the following, which relies on grid_coord rather than on the serialized code:

Thank you very much. I will now replace this code and try it.

QWTforGithub commented 5 months ago

I am not sure either, but try replacing the pooling code with the following, which relies on grid_coord rather than on the serialized code:

Thank you very much. I have replaced the code and tried it, and found the following problem (two screenshots attached): the variable 'code' turns out to be an empty list.

Gofinge commented 5 months ago

Add this line in Point.serialization()

image

QWTforGithub commented 5 months ago

Add this line in Point.serialization()

image

Thank you very much. It runs now. However, the earlier errors only appeared after several epochs, so we need to wait a while before we can tell whether the problem is solved.

Lizhinwafu commented 5 months ago

I can now train on my own data. I did not apply the code changes mentioned above; I only changed the categories of my data. I don't know why that helped. The config I chose is semseg-pt-v3m1-0-rpe.py. (I did not change the loss; I used the losses from the source code, which appear to be two: cross-entropy loss and LovaszLoss.)

QWTforGithub commented 5 months ago

I am not sure either, but try replacing the pooling code with the following, which relies on grid_coord rather than on the serialized code:

Thank you very much. Now the code can train normally. At the same time, I found that this pooling did not affect performance.

Gofinge commented 5 months ago

Fundamentally, the two pooling strategies are the same. The former relies on the existing serialization codes. I will switch the pooling code to the latter, because some researchers want to change the default serialization pattern, and with the former that change would also affect the pooling process.
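
One way to see why the two strategies coincide for power-of-two strides is sketched below in plain Python. It assumes a plain Morton (z-order) code; `morton_encode` and `pooling_depth_for` are hypothetical helpers, and the actual bit layout and serialization orders in Pointcept may differ. Shifting the code right by three bits per pooling level gives the same result as shifting each grid coordinate right by one bit per level and re-encoding.

```python
import math


def morton_encode(x, y, z, depth):
    """Interleave the low `depth` bits of x, y, z into a single integer."""
    code = 0
    for bit in range(depth):
        code |= ((x >> bit) & 1) << (3 * bit + 2)
        code |= ((y >> bit) & 1) << (3 * bit + 1)
        code |= ((z >> bit) & 1) << (3 * bit)
    return code


def pooling_depth_for(stride):
    # Same expression as in serialized_forward.
    return (math.ceil(stride) - 1).bit_length()


depth = 4   # coordinates must be < 2**depth
stride = 2
p = pooling_depth_for(stride)  # one bit level for stride 2
coords = [(0, 0, 0), (1, 1, 1), (2, 3, 5), (7, 6, 15), (8, 8, 8)]

# Pooling on the serialized code: drop 3 bits (one per axis) per level.
by_code = [morton_encode(x, y, z, depth) >> (3 * p) for x, y, z in coords]
# Pooling on grid coordinates: coarsen each axis, then re-encode.
by_grid = [morton_encode(x >> p, y >> p, z >> p, depth - p) for x, y, z in coords]
assert by_code == by_grid

# Only power-of-two strides correspond to a whole number of bit levels.
print([(s, 2 ** pooling_depth_for(s) == s) for s in (2, 3, 4, 6, 8)])
# [(2, True), (3, False), (4, True), (6, False), (8, True)]
```

The last line illustrates the limitation of the code-based variant: for a stride such as 3 or 6 there is no bit shift that yields the intended cell size, whereas dividing grid coordinates by the stride works for any integer stride.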

jotix16 commented 4 months ago

Hi @Gofinge,

grid_forward wrongly pools points from different batches together. The problem arises in torch.unique, which is not aware of the batch index.

It can be solved easily by shifting the batch index into the high bits beforehand and masking it out afterwards:

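        ....
        # Pack the batch index into the high bits so that cells from different
        # batches never compare equal (assumes pooled grid coordinates < 2**16).
        grid_coord = grid_coord | (point.batch.view(-1, 1) << 16)
        grid_coord, cluster, counts = torch.unique(
            grid_coord,
            sorted=True,
            return_inverse=True,
            return_counts=True,
            dim=0,
        )
        # Drop the batch bits again to recover the plain grid coordinates.
        grid_coord = grid_coord & 0xFFFF
        ....
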
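
The effect of the batch shift can be checked on toy data in plain Python. This is a sketch under assumptions: `unique_cells` is a hypothetical helper, coordinates are non-negative, and the pooled coordinates fit in 16 bits so that the batch bits do not collide with them.

```python
def unique_cells(grid_coords, batch, stride, batch_aware):
    """Return the unique pooled cells, optionally keeping batches apart."""
    keys = []
    for (x, y, z), b in zip(grid_coords, batch):
        cell = (x // stride, y // stride, z // stride)
        if batch_aware:
            # Batch index occupies bit 16 and above in every coordinate column.
            cell = tuple(c | (b << 16) for c in cell)
        keys.append(cell)
    unique = sorted(set(keys))
    # Strip the batch bits again, as `& 0xFFFF` does in the snippet above.
    return [tuple(c & 0xFFFF for c in cell) for cell in unique]


grid = [(0, 0, 0), (1, 1, 1), (0, 1, 0), (4, 4, 4)]
batch = [0, 0, 1, 1]

# Without the batch bits, the third point (batch 1) merges with batch 0's cell.
print(unique_cells(grid, batch, stride=2, batch_aware=False))
# [(0, 0, 0), (2, 2, 2)]

# With the batch bits, cell (0, 0, 0) appears once per batch.
print(unique_cells(grid, batch, stride=2, batch_aware=True))
# [(0, 0, 0), (0, 0, 0), (2, 2, 2)]
```

If pooled grid coordinates could reach 2**16 or more, a larger shift (and a wide enough integer dtype) would be needed; that limit is an assumption of the trick, not something the thread states.
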
Gofinge commented 4 months ago

torch.unique

Which code do you mean? I checked PTv3 and PTv2, and both do take the batch into account.

jotix16 commented 4 months ago

Ah sorry,

my comment was about your suggestion in this issue:

As far as I know, there is no grid-pooling for ptv3, or am I wrong?

Gofinge commented 4 months ago

@jotix16 Oh, yes. I remember that now. Thanks for the reminder. BTW: fundamentally, the pooling used in PTv3 is GridPool. The two implementations are actually the same.

jotix16 commented 4 months ago

Exactly, but as you mentioned somewhere, in SerializedPooling you can scale the grid only with powers of two.

GridPooling is more flexible but you have to re-serialize.

Thank you for getting back to me so quickly, I just added the comment in case someone needs to use the grid pooling from this issue.

QWTforGithub commented 3 months ago

Hi, have you trained PTv3 on a single 24 GB GPU? I set the batch size to 2 and the number of workers to 4, and kept everything else the same. I found that the performance of PTv3 fluctuated considerably (76.10 to 77.23). I fixed the random seed, but the loss values are still not consistent. Do you know why this is? How can I configure training to make the performance of PTv3 stable?