klintan / pytorch-lanenet

LaneNet implementation in PyTorch
MIT License
216 stars, 60 forks

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) #17

Closed: ayushmankumar7 closed this issue 3 years ago

ayushmankumar7 commented 3 years ago

Command:

    python lanenet/train.py --dataset ./data/training_data_example

Error:

    Traceback (most recent call last):
      File "lanenet/train.py", line 156, in <module>
        main()
      File "lanenet/train.py", line 144, in main
        train_iou = train(train_loader, model, optimizer, epoch)
      File "lanenet/train.py", line 68, in train
        total_loss, binary_loss, instance_loss, out, train_iou = compute_loss(net_output, binary_label, instance_label)
      File "C:\Users\ayush\AppData\Local\Programs\Python\Python37\lib\site-packages\lanenet-0.1.0-py3.7.egg\lanenet\model\model.py", line 75, in compute_loss
      File "C:\Users\ayush\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "C:\Users\ayush\AppData\Local\Programs\Python\Python37\lib\site-packages\lanenet-0.1.0-py3.7.egg\lanenet\model\loss.py", line 33, in forward
      File "C:\Users\ayush\AppData\Local\Programs\Python\Python37\lib\site-packages\lanenet-0.1.0-py3.7.egg\lanenet\model\loss.py", line 71, in _discriminative_loss
      File "C:\Users\ayush\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\functional.py", line 916, in norm
        return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
    IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
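For context, here is a minimal sketch (plain PyTorch, not code from the repository) of what the last frame of the traceback means: `torch.norm(..., dim=1)` needs a tensor with at least two dimensions, so calling it on a 1-D tensor raises an `IndexError` of this kind, while a 2-D `(num_pixels, embed_dim)` tensor reduces to one norm per row. The helper `norm_over_dim1` and the tensor sizes are illustrative assumptions.

```python
import torch


def norm_over_dim1(x: torch.Tensor) -> torch.Tensor:
    """Per-row L2 norm, the same reduction var_loss applies to (embedding_i - mean_i)."""
    return torch.norm(x, dim=1)


flat = torch.randn(10)       # a 1-D vector, like the result of the boolean indexing
pixels = torch.randn(10, 4)  # the 2-D (num_pixels, embed_dim) layout var_loss expects

print(norm_over_dim1(pixels).shape)  # torch.Size([10])

try:
    norm_over_dim1(flat)  # a 1-D tensor only has dims -1 and 0, so dim=1 is out of range
except IndexError as err:
    print("IndexError:", err)
```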

klintan commented 3 years ago

I think it's caused by this latest commit. I'll try to revert it. I didn't double-check it, and it was said to be in the paper.

https://github.com/klintan/pytorch-lanenet/commit/db9f116ba3f42dbfabf064e4a89ec068e9da4ee4

So if you change that back to what it was, it should not fail. However, I'll have to dig into the paper again to see what the underlying issue might be.

mummy2358 commented 3 years ago

@klintan @ayushmankumar7 Hello, I ran into the same issue and managed to fix it. Hopefully this helps :)

I went through the whole pipeline, and the issue seems to come from the definition of the function "_discriminative_loss" in "loss.py", around line 66 (I marked it with lines of "#"):

    for b in range(batch_size):
        embedding_b = embedding[b]  # (embed_dim, H, W)
        seg_gt_b = seg_gt[b]

        labels = torch.unique(seg_gt_b)
        labels = labels[labels != 0]
        num_lanes = len(labels)
        if num_lanes == 0:
            # please refer to issue here: https://github.com/harryhan618/LaneNet/issues/12
            _nonsense = embedding.sum()
            _zero = torch.zeros_like(_nonsense)
            var_loss = var_loss + _nonsense * _zero
            dist_loss = dist_loss + _nonsense * _zero
            reg_loss = reg_loss + _nonsense * _zero
            continue

        centroid_mean = []
        for lane_idx in labels:
            seg_mask_i = (seg_gt_b == lane_idx)
            if not seg_mask_i.any():
                continue
            ############   PROBLEM IS HERE !!!!!!   ############

            # embedding_i = embedding_b[seg_mask_i]
            # Shapes: embedding_b is [5, 256, 512] and seg_mask_i is [5, 256, 512].
            # Indexing with a boolean mask of the same shape gathers the masked
            # entries of embedding_b into a single 1-D vector, so the subsequent
            # torch.norm(xxx, dim=1) in var_loss throws "Dimension out of range".
            # I changed it to element-wise multiplication and it works :)
            embedding_i = embedding_b * seg_mask_i

            #################################################

            mean_i = torch.mean(embedding_i, dim=0)
            centroid_mean.append(mean_i)

            # ---------- var_loss -------------
            var_loss = var_loss + torch.mean(F.relu(
                torch.norm(embedding_i - mean_i, dim=1) - self.delta_var) ** 2) / num_lanes
        centroid_mean = torch.stack(centroid_mean)  # (n_lane, embed_dim)
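A hedged sketch of the shapes involved, assuming (as the shape comment in the snippet implies) that the instance label is repeated across all 5 embedding channels. The sizes, the rectangular "lane", and `delta_var = 0.5` are hypothetical. It shows that the original boolean indexing flattens to 1-D, that the element-wise version keeps the full `[5, 256, 512]` shape with background pixels zeroed rather than removed, and one alternative (an assumption, not the repository's fix) that masks only the spatial dimensions so each row is one lane pixel's embedding:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes matching the shape comment: embed_dim=5, H=256, W=512.
embed_dim, H, W = 5, 256, 512
embedding_b = torch.randn(embed_dim, H, W)

# Hypothetical instance label: lane 1 is a 10x20 rectangle, the rest is background (0).
seg_gt_2d = torch.zeros(H, W, dtype=torch.long)
seg_gt_2d[100:110, 200:220] = 1

# Assumption: the label is repeated across the embedding channels, as the snippet's
# shape comment implies, so the mask has the same shape as embedding_b.
seg_gt_b = seg_gt_2d.expand(embed_dim, H, W)
seg_mask_i = seg_gt_b == 1

# 1) Original line: a full-shape boolean mask flattens the result to 1-D.
flat = embedding_b[seg_mask_i]
print(flat.shape)  # torch.Size([1000]) = 5 channels * 200 lane pixels

# 2) Element-wise version: shape is preserved, but background pixels remain as zeros.
masked = embedding_b * seg_mask_i
print(masked.shape)  # torch.Size([5, 256, 512])

# 3) Alternative sketch: mask only the spatial dims with a 2-D mask, then
#    transpose to (num_pixels, embed_dim) so dim=0 indexes pixels.
embedding_i = embedding_b[:, seg_mask_i[0]].T
print(embedding_i.shape)  # torch.Size([200, 5])

mean_i = torch.mean(embedding_i, dim=0)         # (embed_dim,) centroid of the lane pixels
dist = torch.norm(embedding_i - mean_i, dim=1)  # (num_pixels,) distance to the centroid
delta_var = 0.5  # hypothetical margin; the real value comes from the loss's self.delta_var
var_term = torch.mean(F.relu(dist - delta_var) ** 2)
print(var_term.shape)  # torch.Size([])
```

With the element-wise version, `torch.mean(embedding_i, dim=0)` averages over the 5 embedding channels and yields a `[256, 512]` map rather than a per-lane centroid, so the loss runs but may not compute the per-lane variance term the discriminative loss is meant to measure. Gathering a `(num_pixels, embed_dim)` matrix keeps `mean(dim=0)` and `norm(dim=1)` consistent with the rest of the snippet.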