isl-org / Open3D-ML

An extension of Open3D to address 3D Machine Learning tasks

Error on training Toronto3d On RandlaNet #634

Open 7Vyshak7 opened 10 months ago

7Vyshak7 commented 10 months ago

I am training the RandLANet model on the Toronto3D dataset, but training fails with this error:

RuntimeError: weight tensor should be defined either for all 8 classes or no classes but got weight tensor of shape: [1, 8]

I have not changed anything in the code or the config file. What could be the problem? The config file and training code are attached below.

Config

dataset:
  name: Toronto3D
  cache_dir: ./logs/cache
  class_weights: [41697357, 1745448, 6572572, 19136493, 674897, 897825, 4634634, 374721]
  ignored_label_inds:

Training Code

import os
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

cfg_file = "Config/randlanet_toronto3d.yml"
cfg = _ml3d.utils.Config.load_from_file(cfg_file)

dataset = ml3d.datasets.Toronto3D(dataset_path='./TrainingData/Toronto3d/',**cfg.dataset)
model = ml3d.models.RandLANet(**cfg.model)

pipeline = ml3d.pipelines.SemanticSegmentation(model=model, dataset=dataset,**cfg.pipeline)
pipeline.run_train()

msanov commented 10 months ago

Hi, you are hitting a bug in the original code in ./ml3d/torch/modules/losses/semseg_loss.py. nn.CrossEntropyLoss expects a 1-D weight tensor with 8 values (the number of valid classes), but the tensor passed to it has shape [1, 8], as the error message shows. Calling .squeeze() on the weights removes the extra dimension. Here is my corrected code (line 40):

def __init__(self, pipeline, model, dataset, device):
    super(SemSegLoss, self).__init__()
    # weighted_CrossEntropyLoss
    if 'class_weights' in dataset.cfg.keys() and len(
            dataset.cfg.class_weights) != 0:
        class_wt = DataProcessing.get_class_weights(
            dataset.cfg.class_weights)
        weights = torch.tensor(class_wt, dtype=torch.float, device=device)
        # Drop the leading dimension: [1, 8] -> [8], the shape CrossEntropyLoss accepts
        weights = weights.squeeze()
        self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss(weight=weights)
    else:
        self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss()

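To see the shape problem in isolation, here is a minimal sketch that does not use Open3D-ML at all. It assumes only PyTorch. The inverse-frequency formula that builds the weights is a stand-in I chose for illustration, not necessarily what DataProcessing.get_class_weights computes; the only property that matters here is the [1, 8] shape reported in the error. Whether the 2-D weight is rejected may depend on the PyTorch version, so the sketch catches the error instead of assuming it.

```python
import torch
import torch.nn as nn

# Per-class point counts taken from the config posted in this thread.
counts = [41697357, 1745448, 6572572, 19136493, 674897, 897825, 4634634, 374721]

# Stand-in weights with shape [1, 8], matching the shape in the error message.
# The formula is an assumption for illustration only.
freq = torch.tensor(counts, dtype=torch.float) / sum(counts)
weights_2d = (1.0 / (freq + 0.02)).unsqueeze(0)

# Dummy batch: 4 points, 8 classes.
logits = torch.randn(4, 8)
labels = torch.tensor([0, 3, 5, 7])

# Some PyTorch versions reject a 2-D weight tensor with a RuntimeError.
try:
    nn.CrossEntropyLoss(weight=weights_2d)(logits, labels)
    print("2-D weight accepted by this PyTorch version")
except RuntimeError as err:
    print("2-D weight rejected:", err)

# The fix from the comment above: squeeze to a 1-D tensor of length 8.
weights_1d = weights_2d.squeeze()
loss = nn.CrossEntropyLoss(weight=weights_1d)(logits, labels)
print(tuple(weights_1d.shape), loss.item())
```

If the first call raises on your install, the message should match the one in the original post, and the squeezed version should go through either way.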