nickgkan / butd_detr

Code for the ECCV22 paper "Bottom Up Top Down Detection Transformers for Language Grounding in Images and Point Clouds"

Issue regarding the adaptation of butd-detr for multiview images input #28

Hiusam closed this issue 1 year ago

Hiusam commented 1 year ago

Hello Authors,

Firstly, I would like to extend my sincere appreciation for your exceptional work on "Bottom Up Top Down Detection Transformers for Language Grounding in Images and Point Clouds". Your work has greatly inspired me, and I am currently working on modifying your model, butd-detr, to accommodate multiview images as input while discarding point clouds.

In my modification, I load multiview images and use ResNet50 to extract image features. These features are backprojected to uniformly sampled positions in the scene to construct a 3D voxel volume, as implemented in ImVoxelNet (https://github.com/SamsungLabs/imvoxelnet). I have replaced the original visual stream (PointNet++) with this backbone and left the rest of the code unchanged.
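The core of the backprojection step is projecting each sampled voxel center into every view and reading the feature at the resulting location. A minimal plain-Python sketch of that per-point projection, assuming a pinhole camera with a 3x3 intrinsic matrix and a 3x4 world-to-camera extrinsic (the posted code passes 4x4 matrices, and `project_point` is a hypothetical helper, not ImVoxelNet's code):

```python
def project_point(point, intrinsic, extrinsic, stride, height, width):
    """Project a world-space point onto one view's feature map.

    intrinsic: 3x3 pinhole camera matrix; extrinsic: 3x4 world-to-camera [R|t];
    stride: downsampling factor between the image and the feature map;
    height, width: feature-map size. Returns integer (u, v) feature-map
    coordinates, or None when the point is behind the camera or off the map.
    """
    # World -> camera coordinates: X_cam = R @ X + t (homogeneous point)
    homog = (*point, 1.0)
    cam = [sum(extrinsic[r][c] * homog[c] for c in range(4)) for r in range(3)]
    if cam[2] <= 0:  # behind the camera: this view cannot see the voxel
        return None
    # Camera -> pixel coordinates, then perspective divide by depth
    pix = [sum(intrinsic[r][c] * cam[c] for c in range(3)) for r in range(3)]
    u_img, v_img = pix[0] / pix[2], pix[1] / pix[2]
    # Pixel -> feature-map coordinates (the 2D backbone downsamples by `stride`)
    u, v = round(u_img / stride), round(v_img / stride)
    if 0 <= u < width and 0 <= v < height:
        return u, v
    return None
```

If the extrinsics fed into such a projection are camera-to-world poses instead of world-to-camera transforms, voxels land on the wrong pixels or are marked invalid, so the matrix direction is worth ruling out when accuracy is exactly zero.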

One difference is in the compute_points_obj_cls_loss_hard_topk function: the voxels (i.e., the positions I sampled in space) do not belong to any object's point cloud, so I instead mark the four voxels nearest to each ground-truth center as positive.
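The assignment described above can be sketched in plain Python as follows, assuming voxel centers and ground-truth boxes are (x, y, z) tuples in the same coordinate frame. `topk_positive_voxels` is a hypothetical helper, not part of butd-detr; it mirrors the size-normalized distance used in the posted loss, with k = 4 for the setup described here.

```python
import math


def topk_positive_voxels(voxel_xyz, gt_centers, gt_sizes, k):
    """Return indices of voxels labeled positive for objectness.

    For each ground-truth box, the k voxels closest to the box center are
    positive; distances are normalized per axis by the box size, as in the
    posted compute_points_obj_cls_loss_hard_topk. All other voxels are negative.
    """
    positives = set()
    for center, size in zip(gt_centers, gt_sizes):
        dists = [
            math.sqrt(sum(((p - c) / (s + 1e-6)) ** 2
                          for p, c, s in zip(xyz, center, size)))
            for xyz in voxel_xyz
        ]
        # Argsort by distance and keep the k nearest voxel indices
        nearest = sorted(range(len(voxel_xyz)), key=dists.__getitem__)[:k]
        positives.update(nearest)
    return positives
```

One consequence of this design: if the voxel grid has many more positions than the 1024 PointNet++ seeds, k positives per object make up a much smaller fraction of all candidates, so the objectness targets are considerably sparser than in the point-cloud setting.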

Additionally, I have chosen not to use the box stream since my goal is to make the model image-based only.

Despite my efforts, the model's accuracy is consistently zero after training. I have been trying to troubleshoot this and am reaching out to ask whether I am missing something critical.

Any insights or suggestions you could provide would be of immense help. Thank you very much for your time and consideration.

Best regards,

Hiusam

Hiusam commented 1 year ago

For reference, here is the backbone code:

    def _run_backbones(self, inputs):
        """Run visual and text backbones.
            inputs contain:
                'multiview_imgs': Tensor(B, n_images, 3, H, W)
                'extrinsics': Tensor(B, n_images, 4, 4)
                'intrinsic': Tensor(B, 4, 4)

        """
        end_points = {}
        # * multiview_imgs encoder
        if self.use_multiview_image:
            x, points = self._imvoxelnet_backbones(inputs)
            end_points = {
                'seed_xyz': points,
                'seed_features': x,
                'use_multiview_image': True
            }
        # Point encoder
        else:
            end_points_old = self.backbone_net(inputs['point_clouds'], end_points={})

            end_points['seed_inds'] = end_points_old['fp2_inds'] # * B, 1024
            end_points['seed_xyz'] = end_points_old['fp2_xyz'] # * B, 1024, 3
            end_points['seed_features'] = end_points_old['fp2_features'] # * B, C, 1024(point number) 
            end_points['use_multiview_image'] = False
        # Text encoder
        tokenized = self.tokenizer.batch_encode_plus(
            inputs['text'], padding="longest", return_tensors="pt"
        ).to(end_points['seed_xyz'].device)  # * 'point_clouds' may be absent in multiview mode
        encoded_text = self.text_encoder(**tokenized)
        text_feats = self.text_projector(encoded_text.last_hidden_state)
        # Invert attention mask that we get from huggingface
        # because its the opposite in pytorch transformer
        text_attention_mask = tokenized.attention_mask.ne(1).bool()
        end_points['text_feats'] = text_feats
        end_points['text_attention_mask'] = text_attention_mask
        end_points['tokenized'] = tokenized
        return end_points

    def _imvoxelnet_backbones(self, inputs):
        """
        ImVoxelNet Backbone
            inputs contain:
                'multiview_imgs': Tensor(B, n_images, 3, H, W)
                'extrinsics': Tensor(B, n_images, 4, 4)
                'intrinsic': Tensor(B, 4, 4)
                'origin': Tensor(B, 3)
                'img_shape': Tensor(B, 3)
                'ori_shape', 'scale_factor': per-sample image metadata
        """
        img = inputs['multiview_imgs']
        extrinsics = inputs['extrinsics']
        intrinsics = inputs['intrinsic']
        origins = inputs['origin']
        img_shapes = inputs['img_shape']
        ori_shapes = inputs['ori_shape']
        scale_factors = inputs['scale_factor']

        # * extract features from each view
        batch_size, num_views, C, img_h, img_w = img.shape
        img = img.reshape((-1, C, img_h, img_w))

        x = self.backbone(img)
        x = self.neck(x)[0]
        x = x.reshape((batch_size, num_views, *x.shape[1:]))

        stride = img.shape[-1] / x.shape[-1]
        assert stride == 4  # may be removed in the future
        stride = int(stride)

        volumes, batch_points = [], []
        for batch_id, feature in enumerate(x):
            img_meta = {
                'lidar2img': {
                    'origin': origins[batch_id],
                    'extrinsic': extrinsics[batch_id],
                    'intrinsic': intrinsics[batch_id],
                },
                'img_shape': img_shapes[batch_id],
                'ori_shape': ori_shapes[batch_id],
                'scale_factor': scale_factors[batch_id]
            }
            projection = self._compute_projection(img_meta, stride).to(x.device)
            points = self.get_points(
                n_voxels=torch.tensor(self.n_voxels, device=x.device), # * n_voxels: x, y, z
                voxel_size=torch.tensor(self.voxel_size,device=x.device), # * voxel_size: x, y, z
                origin=img_meta['lidar2img']['origin']
            )
            height = torch.div(img_meta['img_shape'][0], stride, rounding_mode='trunc')
            width = torch.div(img_meta['img_shape'][1], stride, rounding_mode='trunc')
            volume, valid = self.backproject(feature[:, :, :height, :width], points, projection)
            volume = volume.sum(dim=0)
            valid = valid.sum(dim=0)
            volume = volume / valid
            valid = valid > 0
            volume[:, ~valid[0]] = .0
            volumes.append(volume)
            batch_points.append(points)
        x = torch.stack(volumes) # * x to be [B, C, n_voxel_x, n_voxel_y, n_voxel_z]
        batch_points = torch.stack(batch_points)  # * points to be [B, 3, n_voxel_x, n_voxel_y, n_voxel_z]
        x = self.neck_3d(x)  # * x becomes a list of tensors [B, C2, n_voxel_x/n, n_voxel_y/n, n_voxel_z/n] (n=1,2,4)
        # * need to return voxel positions
        return x[0].reshape((batch_size, x[0].shape[1], -1)), batch_points.reshape((batch_size, 3, -1)).transpose(1, 2)
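The multiview fusion in the loop above (`volume.sum(dim=0)`, `valid.sum(dim=0)`, the division, then zeroing unseen voxels) amounts to a per-voxel mean over the views that see each voxel. A plain-Python sketch, assuming `backproject` leaves zeros for views in which a voxel is invalid (so summing only the valid views gives the same result); `average_over_views` is a hypothetical name:

```python
def average_over_views(view_feats, view_valid):
    """Fuse per-view backprojected features into one feature per voxel.

    view_feats[v][i] is the feature vector (list of floats) voxel i received
    from view v; view_valid[v][i] says whether voxel i projected inside view v.
    Voxels seen by no view get an all-zero feature.
    """
    num_voxels = len(view_feats[0])
    dim = len(view_feats[0][0])
    fused = []
    for i in range(num_voxels):
        seen = [v for v in range(len(view_feats)) if view_valid[v][i]]
        if not seen:
            fused.append([0.0] * dim)  # no view sees this voxel
            continue
        fused.append([sum(view_feats[v][i][d] for v in seen) / len(seen)
                      for d in range(dim)])
    return fused
```

For voxels seen by no view, the posted code computes 0 / 0 (NaN) before `volume[:, ~valid[0]] = .0` overwrites it; the sketch avoids that intermediate by checking for an empty view list first.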

And here is my version of compute_points_obj_cls_loss_hard_topk:

def compute_points_obj_cls_loss_hard_topk(end_points, topk):
    """
    For reference:
        point_instance_label = -np.ones(len(scan.pc))
        for t, tid in enumerate(tids):
            point_instance_label[scan.three_d_objects[tid]['points']] = t # * note here not the object id (tid), but the index of the object in the scene
    """
    box_label_mask = end_points['box_label_mask'] # B, G(max_num_objects)
    seed_xyz = end_points['seed_xyz']  # B, K, 3
    seeds_obj_cls_logits = end_points['seeds_obj_cls_logits']  # B, 1, K
    gt_center = end_points['center_label'][:, :, :3]  # B, G, 3
    gt_size = end_points['size_gts'][:, :, :3]  # B, G, 3
    B = gt_center.shape[0]  # batch size
    K = seed_xyz.shape[1]  # number of seed points (p++ output, or voxels in multiview mode)
    G = gt_center.shape[1]  # number of gt boxes (with padding)

    # Assign each point to a GT object
    # * find the object index of each seed points, for those points without object, assign to G-1
    # * for those points with target object, assign 0, 1, ...
    if not end_points['use_multiview_image']:
        seed_inds = end_points['seed_inds'].long()  # B, K 
        point_instance_label = end_points['point_instance_label']  # B, num_points # * of target objects only
        obj_assignment = torch.gather(point_instance_label, 1, seed_inds)  # B, K # * find the object index of the selected points
        obj_assignment[obj_assignment < 0] = G - 1  # bg points to last gt # * in-place: points belonging to no object are assigned G-1
        # * for each B, K, it has an one-hot vector, indicating whether it is object i or bg(G-1)
        obj_assignment_one_hot = torch.zeros((B, K, G)).to(seed_xyz.device)
        obj_assignment_one_hot.scatter_(2, obj_assignment.unsqueeze(-1), 1) # * B, K, G
    else:
        # * use multi-view image
        # * no need to distinguish bg points and fg points
        obj_assignment_one_hot = torch.ones((B, K, G), device=seed_xyz.device)  # B, K, G

    # Normalized distances of points and gt centroids
    # * For each seed point, find the distance to each gt center
    delta_xyz = seed_xyz.unsqueeze(2) - gt_center.unsqueeze(1)  # (B, K, G, 3)
    delta_xyz = delta_xyz / (gt_size.unsqueeze(1) + 1e-6)  # (B, K, G, 3)
    new_dist = torch.sum(delta_xyz ** 2, dim=-1)
    euclidean_dist1 = torch.sqrt(new_dist + 1e-6)  # BxKxG
    # * Each point K_i will only have the distance to the gt object it belongs to
    # * Other distances are set to 100
    euclidean_dist1 = (
        euclidean_dist1 * obj_assignment_one_hot
        + 100 * (1 - obj_assignment_one_hot)
    )  # BxKxG
    euclidean_dist1 = euclidean_dist1.transpose(1, 2).contiguous()  # BxGxK

    # Find the points that lie closest to each gt centroid
    topk_inds = (
        torch.topk(euclidean_dist1, topk, largest=False)[1]
        * box_label_mask[:, :, None]
        + (box_label_mask[:, :, None] - 1)
    )  # BxGxtopk
    topk_inds = topk_inds.long()  # BxGxtopk
    topk_inds = topk_inds.view(B, -1).contiguous()  # B, Gxtopk
    batch_inds = torch.arange(B)[:, None].repeat(1, G*topk).to(seed_xyz.device)
    batch_topk_inds = torch.stack([
        batch_inds,
        topk_inds
    ], -1).view(-1, 2).contiguous()

    # Topk points closest to each centroid are marked as true objects
    objectness_label = torch.zeros((B, K + 1)).long().to(seed_xyz.device)
    objectness_label[batch_topk_inds[:, 0], batch_topk_inds[:, 1]] = 1
    objectness_label = objectness_label[:, :K]
    if not end_points['use_multiview_image']:
        objectness_label_mask = torch.gather(point_instance_label, 1, seed_inds) # * only consider foreground points, bg points are all negative
        objectness_label[objectness_label_mask < 0] = 0

    # Compute objectness loss
    criterion = SigmoidFocalClassificationLoss()
    cls_weights = (objectness_label >= 0).float()
    cls_normalizer = cls_weights.sum(dim=1, keepdim=True).float()
    cls_weights /= torch.clamp(cls_normalizer, min=1.0)
    cls_loss_src = criterion(
        seeds_obj_cls_logits.view(B, K, 1),
        objectness_label.unsqueeze(-1),
        weights=cls_weights
    )
    objectness_loss = cls_loss_src.sum() / B

    return objectness_loss
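The line `topk_inds * box_label_mask + (box_label_mask - 1)` is easy to misread: for padded ground-truth slots (mask 0) it yields index -1, which writes into the extra (K+1)-th entry of `objectness_label` that is sliced off afterwards, so padding never creates positives. A plain-Python sketch of just that step (the helper name is hypothetical):

```python
def mark_topk_positives(topk_inds, box_label_mask, num_seeds):
    """Objectness labels from per-box top-k seed indices, ignoring padding.

    topk_inds[g] holds the seed indices nearest to GT slot g; box_label_mask[g]
    is 1 for a real box and 0 for a padded slot.
    """
    labels = [0] * (num_seeds + 1)  # one spare slot absorbs padded boxes
    for inds, m in zip(topk_inds, box_label_mask):
        for idx in inds:
            # Real box (m=1): idx; padded slot (m=0): -1, i.e. the spare slot
            labels[idx * m + (m - 1)] = 1
    return labels[:num_seeds]  # drop the spare slot
```

For example, `mark_topk_positives([[0, 2], [1, 3]], [1, 0], 4)` returns `[1, 0, 1, 0]`: only the real box's neighbors are marked.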
ayushjain1144 commented 1 year ago

Hi, I cannot think of anything butd-detr-specific that might be causing this issue. Some things that come to mind for investigating further: a) visualize the unprojected point cloud and make sure the ground-truth boxes overlaid on the unprojected xyz coordinates make sense; b) try some overfitting experiments where you train on about 10 scenes and evaluate on the same 10. If there are no bugs, you should be able to get 100% accuracy (you could use the --debug flag for this).
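For suggestion (a), a cheap numeric complement to visualization is to check that every ground-truth center falls inside the region spanned by the sampled voxel positions; if many do not, `seed_xyz` and `center_label` are probably in different coordinate frames. This sketch assumes, as I understand ImVoxelNet's `get_points`, that `origin` is the center of the voxel grid; both helper names are hypothetical.

```python
def voxel_grid_bounds(origin, n_voxels, voxel_size):
    """Axis-aligned extent spanned by the voxel positions.

    Assumes `origin` is the grid center, so voxel i along an axis sits at
    origin - n / 2 * size + i * size for i in 0..n-1.
    """
    lo = [o - n / 2 * s for o, n, s in zip(origin, n_voxels, voxel_size)]
    hi = [l + (n - 1) * s for l, n, s in zip(lo, n_voxels, voxel_size)]
    return lo, hi


def centers_outside_grid(gt_centers, origin, n_voxels, voxel_size):
    """Indices of ground-truth centers lying outside the voxel extent."""
    lo, hi = voxel_grid_bounds(origin, n_voxels, voxel_size)
    return [g for g, c in enumerate(gt_centers)
            if any(x < a or x > b for x, a, b in zip(c, lo, hi))]
```

For instance, with `origin=(0, 0, 1)`, `n_voxels=(4, 4, 2)` and `voxel_size=(0.5, 0.5, 0.5)`, the centers `[(0, 0, 0.8), (2, 0, 0.8), (0, 0, 0)]` give `[1, 2]`.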

Hiusam commented 1 year ago

Dear Ayush Jain,

I tried your original implementation (point cloud) and found that it is also difficult to overfit. With 10 samples, the model reaches the following accuracies after 430 epochs. I set butd_cls = True.

[05/24 22:41:07 debug]: last_ Box given span (soft-token) Acc: 20.0
[05/24 22:41:07 debug]: last_ Box given span (contrastive) Acc: 10.0
[05/24 22:41:07 debug]: proposal_ Box given span (soft-token) Acc: 25.0
[05/24 22:41:07 debug]: proposal_ Box given span (contrastive) Acc: 35.0
[05/24 22:41:07 debug]: 0head_ Box given span (soft-token) Acc: 10.0
[05/24 22:41:07 debug]: 0head_ Box given span (contrastive) Acc: 20.0
[05/24 22:41:07 debug]: 1head_ Box given span (soft-token) Acc: 30.0
[05/24 22:41:07 debug]: 1head_ Box given span (contrastive) Acc: 20.0
[05/24 22:41:07 debug]: 2head_ Box given span (soft-token) Acc: 20.0
[05/24 22:41:07 debug]: 2head_ Box given span (contrastive) Acc: 25.0
[05/24 22:41:07 debug]: 3head_ Box given span (soft-token) Acc: 10.0
[05/24 22:41:07 debug]: 3head_ Box given span (contrastive) Acc: 20.0
[05/24 22:41:07 debug]: 4head_ Box given span (soft-token) Acc: 15.0
[05/24 22:41:07 debug]: 4head_ Box given span (contrastive) Acc: 10.0

Did something go wrong?

Best, Hiusam.

ayushjain1144 commented 1 year ago

Ah okay, I think 10 could be too few, and you might need to play around with the batch size/learning rate. Maybe just try the original overfitting set size, i.e. 128 samples, for both the original point-cloud setting and your multiview setting.

Hiusam commented 1 year ago

Hi Ayush Jain,

I found that with the original point-cloud setting, the model can overfit 48 samples to 90% accuracy. With only 1 sample, however, the accuracy is always 0. Why can't the model overfit just a few samples?

Best, Hiusam

ayushjain1144 commented 1 year ago

With one sample, I would expect the model to overfit, but perhaps not with the original hyperparameters. You might need to try a lower learning rate to make it work.

Let us know what you find!

Hiusam commented 1 year ago

Dear Ayush,

My multiview image setting cannot overfit. Do you think there is a bug or do I need to adjust my hyperparameters?

ayushjain1144 commented 1 year ago

It's hard to say with certainty, but if you have already experimented a bit with the hyperparameters, it is most likely a bug.

Hiusam commented 1 year ago

Dear Ayush,

Thank you so much for your reply. But why do the hyperparameters play such an essential role when overfitting? It should be easy for a neural net to overfit only some samples with most hyperparameter settings.

Thanks.

ayushjain1144 commented 1 year ago

"It should be easy for a neural net to overfit only some samples with most hyperparameter settings."

I am not sure that's accurate. If the learning rate is too high, the model will diverge; if it is very low and the loss landscape has local minima, the model might get stuck in one, irrespective of how many samples you have.
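The learning-rate point can be seen on the simplest possible loss, a 1-D quadratic f(x) = a/2 * x^2, where each gradient-descent step multiplies x by (1 - lr * a): it converges only for 0 < lr < 2/a and crawls when lr is tiny. This toy says nothing about local minima (a quadratic has none); it only illustrates the divergence and slow-progress regimes.

```python
def gradient_descent(lr, curvature=1.0, x0=1.0, steps=50):
    """Run plain gradient descent on f(x) = curvature / 2 * x**2.

    The gradient is curvature * x, so each step computes
    x <- (1 - lr * curvature) * x; the iterates shrink toward the minimum at 0
    only when |1 - lr * curvature| < 1, i.e. 0 < lr < 2 / curvature.
    """
    x = x0
    for _ in range(steps):
        x -= lr * curvature * x
    return x


for lr in (1e-4, 0.5, 2.5):
    print(f"lr={lr}: x after 50 steps = {gradient_descent(lr):.3g}")
```

With curvature 1, lr=0.5 drives x to about 1e-15, lr=2.5 blows it up to the order of 1e8, and lr=1e-4 leaves it near 0.995 after 50 steps.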

I suspect that with very few samples in the overfitting experiment, the gradients are noisier, so we either need to tune the hyperparameters again (I would expect a lower lr) or just use a reasonably large overfitting set (like 128 scenes, or the 48 samples you tried). There might also be a better learning rate that works in both settings. For your debugging, I would use 128 samples with the original learning rate and debug until the model overfits.

If you find something more, please let us know!