tsunghan-wu / RandLA-Net-pytorch

🍀 PyTorch Implementation of RandLA-Net (https://arxiv.org/abs/1911.11236)
MIT License

RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19] #26

Open pomeloooo opened 1 year ago

pomeloooo commented 1 year ago

RuntimeError: weight tensor should be defined either for all 19 classes or no classes but got weight tensor of shape: [1, 19]

How can I solve this?
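The message comes from nn.CrossEntropyLoss, which expects its weight argument to be a 1-D tensor with one entry per class. The minimal sketch below reproduces the error on CPU. It assumes the class weights arrive with shape (1, 19), as the error reports. The all-ones weights and random logits are placeholders, not the real SemanticKITTI values, and .cuda() is dropped so it runs without a GPU.

```python
import numpy as np
import torch
import torch.nn as nn

NUM_CLASSES = 19

# Placeholder weights with the same (1, 19) shape reported in the error.
weights_2d = torch.from_numpy(np.ones((1, NUM_CLASSES))).float()

logits = torch.randn(8, NUM_CLASSES)          # (num_points, num_classes)
labels = torch.randint(0, NUM_CLASSES, (8,))  # (num_points,)

# The constructor accepts the 2-D tensor; the shape check fires on the call.
criterion = nn.CrossEntropyLoss(weight=weights_2d, reduction='none')
try:
    criterion(logits, labels)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```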

magical1998 commented 1 year ago

I have the same error. Have you solved it?

magical1998 commented 1 year ago

The weight's shape should be (19,), not (1, 19). In train_SemanticKITTI.py, line 94, squeeze out the extra dimension:

    class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().cuda()
    # The weights arrive with shape (1, 19); CrossEntropyLoss needs a 1-D (19,) tensor
    class_weights = class_weights.squeeze(0)
    self.criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='none')
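A self-contained sketch of the same fix, runnable on CPU. The helper name make_criterion is hypothetical (it is not in the repo), and the all-ones array is a placeholder for whatever get_class_weight() returns.

```python
import numpy as np
import torch
import torch.nn as nn

NUM_CLASSES = 19


def make_criterion(class_weights_np):
    """Hypothetical helper: build the loss with a 1-D weight tensor."""
    class_weights = torch.from_numpy(class_weights_np).float()
    # squeeze(0) drops a leading size-1 dim: (1, 19) -> (19,).
    # On an already 1-D (19,) tensor it is a no-op, since dim 0 has size 19.
    class_weights = class_weights.squeeze(0)
    return nn.CrossEntropyLoss(weight=class_weights, reduction='none'), class_weights


criterion, w = make_criterion(np.ones((1, NUM_CLASSES)))
logits = torch.randn(8, NUM_CLASSES)          # (num_points, num_classes)
labels = torch.randint(0, NUM_CLASSES, (8,))  # (num_points,)
per_point_loss = criterion(logits, labels)    # reduction='none' keeps one loss per point
print(w.shape, per_point_loss.shape)
```

Because squeeze(0) leaves a correctly shaped weight untouched, the line stays safe even if get_class_weight() is later changed to return a 1-D array.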

Wansit99 commented 1 year ago


Changing line 94 to the following also makes training run:

    class_weights = torch.from_numpy(train_dataset.get_class_weight()).float().view(19).cuda()  # reshape (1, 19) -> (19,)
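Both replies remove the leading size-1 dimension; they differ only in how. The sketch below compares squeeze(0), view(19) and reshape(-1) on a placeholder (1, 19) array (random values, not the real class weights; .cuda() dropped so it runs on CPU).

```python
import numpy as np
import torch

NUM_CLASSES = 19

# Placeholder standing in for get_class_weight(); shape (1, 19).
raw = np.random.rand(1, NUM_CLASSES)
t = torch.from_numpy(raw).float()

a = t.squeeze(0)   # first reply: drop the leading size-1 dim
b = t.view(19)     # this reply: hardcodes the class count
c = t.reshape(-1)  # infers the length from the number of elements

print(a.shape, b.shape, c.shape)
```

All three give the same (19,) tensor here. reshape(-1) avoids hardcoding 19, but it flattens any shape, so an array with the wrong number of entries would only surface later as the same RuntimeError; view(19) fails immediately if the element count is not 19.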