isl-org / Open3D-ML

An extension of Open3D to address 3D Machine Learning tasks

collate_fn problem for training with RandLANet #438

Closed EtienneBEGUIER closed 2 years ago

EtienneBEGUIER commented 2 years ago

Good morning,

I am trying to train a RandLANet model on a KITTI dataset. I am running the following code in a Google Colab environment.


model = RandLANet(name='RandLANet', k_n=16, num_layers=4, num_points=45056, num_classes=19, ignored_label_inds=[0], sub_sampling_ratio=[4, 4, 4, 4],
                  dim_input=3, dim_feature=8, dim_output=[16, 64, 128, 256], grid_size=0.06,
                  batcher='DefaultBatcher', ckpt_path=None, weight_decay=0.0)

optimizer = {
  "adam_lr": 0.001,
  "betas": [0.95, 0.99],
  "weight_decay": 0.01,
  "scheduler_gamma":0.95
}
dataset = ml3d.datasets.KITTI(dataset_path='/content/Open3D-ML/examples/demo_data/KITTI', cache_dir='./logs/cache',training_split=['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'])

pipeline = ObjectDetection(model=model, dataset=dataset, name='SemanticSegmentation',
    batch_size=4,
    val_batch_size=4,
    test_batch_size=3,
    max_epoch=100,
    learning_rate=1e-2,
    lr_decays=0.95,
    save_ckpt_freq=20,
    adam_lr=1e-2,
    scheduler_gamma=0.95,
    momentum=0.98,
    main_log_dir='./logs/',
    device='gpu',
    split='train',
    train_sum_dir='train_log',
    optimizer=optimizer)

pipeline.run_train()

I get the following error:

Exception                                 Traceback (most recent call last)
<ipython-input-51-9fa0b75b4819> in <module>()
     46     collate_fn=batcher1.collate_fn)
     47 
---> 48 pipeline.run_train()

6 frames
/usr/local/lib/python3.7/dist-packages/open3d/_ml3d/torch/pipelines/object_detection.py in run_train(self)
    348             # --------------------- validation
    349             if (epoch % cfg.get("validation_freq", 1)) == 0:
--> 350                 self.run_valid()
    351 
    352             self.save_logs(writer, epoch)

/usr/local/lib/python3.7/dist-packages/open3d/_ml3d/torch/pipelines/object_detection.py in run_valid(self, epoch)
    184         gt = []
    185         with torch.no_grad():
--> 186             for data in tqdm(valid_loader, desc='validation'):
    187                 data.to(device)
    188                 results = model(data)

/usr/local/lib/python3.7/dist-packages/tqdm/std.py in __iter__(self)
   1178 
   1179         try:
-> 1180             for obj in iterable:
   1181                 yield obj
   1182                 # Update and possibly print the progressbar.

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1197             else:
   1198                 del self._task_info[idx]
-> 1199                 return self._process_data(data)
   1200 
   1201     def _try_put_index(self):

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1223         self._try_put_index()
   1224         if isinstance(data, ExceptionWrapper):
-> 1225             data.reraise()
   1226         return data
   1227 

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    427             # have message field
    428             raise self.exc_type(message=msg)
--> 429         raise self.exc_type(msg)
    430 
    431 

Exception: Caught Exception in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.7/dist-packages/open3d/_ml3d/torch/dataloaders/concat_batcher.py", line 565, in collate_fn
    f"Please define collate_fn for {self.model}, or use Default Batcher"
Exception: Please define collate_fn for RandLANet, or use Default Batcher

I haven't found a way to use the Default Batcher's collate_fn to resolve this problem. Is there a way to define a collate_fn function for an Open3D-ML model? Why does the model not use the default batcher?

sanskar107 commented 2 years ago

Hey, KITTI is an object detection dataset, whereas RandLANet is a semantic segmentation model. I think you meant to use SemanticKITTI.
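
To make the error message in the traceback easier to read: it is raised from `collate_fn` in `concat_batcher.py`, and it names the model, which suggests that the batcher reached through the object detection pipeline chooses its collation routine by model name and gives up when the name is not one it knows. The sketch below is a toy stand-in for that pattern, not the library's code. `ToyConcatBatcher`, `ToyDetector`, and the lookup table are hypothetical names used only for illustration; only the error text is copied from the traceback.

```python
class ToyConcatBatcher:
    """Hypothetical stand-in for a batcher that picks collation by model name."""

    def __init__(self, model_name):
        self.model = model_name
        # Illustrative lookup table only; this is NOT Open3D-ML's real list.
        self._collators = {"ToyDetector": self._collate_detection}

    @staticmethod
    def _collate_detection(batch):
        # Gather each per-sample field into a list, keyed by field name.
        return {key: [sample[key] for sample in batch] for key in batch[0]}

    def collate_fn(self, batch):
        collate = self._collators.get(self.model)
        if collate is None:
            # Same message as in the traceback above.
            raise Exception(
                f"Please define collate_fn for {self.model}, or use Default Batcher"
            )
        return collate(batch)


samples = [{"point": [0.0, 0.0, 0.0]}, {"point": [1.0, 2.0, 3.0]}]

# A model name the toy batcher knows: collation succeeds.
ok = ToyConcatBatcher("ToyDetector").collate_fn(samples)

# A model name it does not know: the same exception text as in the issue.
try:
    ToyConcatBatcher("RandLANet").collate_fn(samples)
    error_message = None
except Exception as exc:
    error_message = str(exc)

print(error_message)
```

Read this way, the message is less about a missing `collate_fn` and more a sign that the model and the pipeline do not match, which is consistent with the suggestion to pair RandLANet with SemanticKITTI in a semantic segmentation setup.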