meetps / pytorch-semseg

Semantic Segmentation Architectures Implemented in PyTorch
https://meetshah.dev/semantic-segmentation/deep-learning/pytorch/visdom/2017/06/01/semantic-segmentation-over-the-years.html
MIT License

SegNet on Pascal: TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists #245

farnaznouraei closed 3 years ago

farnaznouraei commented 3 years ago

Hi. I am trying to train SegNet on a small subset of Pascal VOC 2012. I wrote a second data loader for Pascal so that I could overfit on that subset (SegNet gave me very bad results on the full VOC dataset). The only change is that it reads shortened versions of train.txt, val.txt and trainval.txt. I get the following error message after starting the training:

train_reduced.py:223: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details.
  cfg = yaml.load(fp)
RUNDIR: runs/config_segnet./2000
sbd_path = /home/ubuntu/data/VOC/benchmark_RELEASE
expected number of files: 30
sbd_path = /home/ubuntu/data/VOC/benchmark_RELEASE
expected number of files: 30
image names: 2007_000250
image names: 2007_000039
image names: 2007_000063
image names: 2007_000068
image names: 2007_000243
image names: 2007_000032
image names: 2007_000121
image names: 2007_000170
image names: 2007_000256
Traceback (most recent call last):
  File "train_reduced.py", line 234, in <module>
image names: 2007_000241
    train(cfg, writer, logger)
  File "train_reduced.py", line 123, in train
    for (images, labels) in trainloader:
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 81, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.JpegImagePlugin.JpegImageFile'>

My config is as follows:

# Model Configuration
model:
  arch: segnet

# Data Configuration
data:
  dataset: pascal
  train_split: train
  val_split: val
  img_rows: 318
  img_cols: 500
  path: /home/ubuntu/data/VOC/VOCdevkit/VOC2012
  sbd_path: /home/ubuntu/data/VOC/benchmark_RELEASE/

# Training Configuration
training:
    n_workers: 1
    train_iters: 2000
    batch_size: 6
    val_interval: 3000
    print_interval: 50
loss:
    name: 'cross_entropy'
    size_average: False

# Optimizer Configuration
optimizer:
    name: 'sgd'
    lr: 1.0e-5

l_rate: 1.0e-5
lr_schedule:
  name: constant_lr
momentum: 0.99
weight_decay: 0.0005

# LR Schedule Configuration
lr_schedule:

# Resume from checkpoint  
resume: null

Could someone please help me with this?

farnaznouraei commented 3 years ago

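Resolved. I had turned the is_transform flag off in train.py when creating the data loader, so the loader was returning raw PIL images instead of tensors, which is what default_collate was rejecting.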
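
For anyone hitting the same error, here is a minimal sketch of why the flag matters. It is not the repository's actual Pascal loader: `ToySegDataset`, its `transform` method and the `first_batch` helper are hypothetical stand-ins that only mimic the pattern of converting PIL images to tensors when `is_transform` is on. It assumes `torch`, `numpy` and `Pillow` are installed.

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ToySegDataset(Dataset):
    """Hypothetical stand-in for a segmentation dataset with an is_transform flag."""

    def __init__(self, n_samples=4, size=(8, 8), is_transform=True):
        self.n_samples = n_samples
        self.size = size  # (rows, cols)
        self.is_transform = is_transform

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        rows, cols = self.size
        # Synthetic image/label pair, built as PIL images to mimic Image.open().
        img = Image.fromarray(np.full((rows, cols, 3), index, dtype=np.uint8))
        lbl = Image.fromarray(np.zeros((rows, cols), dtype=np.uint8))
        if self.is_transform:
            img, lbl = self.transform(img, lbl)
        # With is_transform off, raw PIL images are returned here.
        return img, lbl

    def transform(self, img, lbl):
        # HWC uint8 image -> CHW float tensor; label -> int64 tensor.
        img = np.array(img).astype(np.float32) / 255.0
        img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        lbl = torch.from_numpy(np.array(lbl).astype(np.int64))
        return img, lbl


def first_batch(is_transform):
    """Build a loader with the default collate function and fetch one batch."""
    loader = DataLoader(ToySegDataset(is_transform=is_transform), batch_size=2)
    return next(iter(loader))


if __name__ == "__main__":
    images, labels = first_batch(is_transform=True)
    print(images.shape, labels.shape)

    try:
        first_batch(is_transform=False)
    except TypeError as err:
        # default_collate cannot stack PIL images into a batch.
        print("collate failed:", err)
```

With the flag on, the default collate function can stack the per-sample tensors into a batch. With it off, the same loader raises the TypeError from the traceback above, because PIL images are not among the types it knows how to batch.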