ShichenLiu / CondenseNet

CondenseNet: Lightweight CNN for mobile devices
MIT License

Cuda runtime error ClassNLLCriterion assertion #10

Closed vponcelo closed 5 years ago

vponcelo commented 6 years ago

Hi,

I am facing the following problem when I attempt to train the network with:

python main.py --model condensenet -b 256 -j 26 person-reid/market1501 --stages 4-6-8-10-8 --growth 8-16-32-64-128 --gpu 0,1,2,3 --savedir person-reid/results_market1501-24kgen --resume
/opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THCUNN/ClassNLLCriterion.cu:57: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int, long) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
[... the same `t >= 0 && t < n_classes` assertion failure is repeated for many other threads in block [0,0,0], e.g. [1,0,0] through [30,0,0] ...]
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/generic/THCStorage.c line=32 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "main.py", line 480, in <module>
    main()
  File "main.py", line 239, in main
    train(train_loader, model, criterion, optimizer, epoch)
  File "main.py", line 314, in train
    prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
  File "main.py", line 474, in accuracy
    correct_k = correct[:k].view(-1).float().sum(0)
  File "/mnt/storage/home/vp17941/.conda/envs/condensenet/lib/python3.6/site-packages/torch/tensor.py", line 43, in float
    return self.type(type(self).__module__ + '.FloatTensor')
  File "/mnt/storage/home/vp17941/.conda/envs/condensenet/lib/python3.6/site-packages/torch/cuda/__init__.py", line 278, in type
    return super(_CudaBase, self).type(*args, **kwargs)
  File "/mnt/storage/home/vp17941/.conda/envs/condensenet/lib/python3.6/site-packages/torch/_utils.py", line 35, in _type
    return new_type(self.size()).copy_(self, async)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/THCTensorCopy.cu:204
terminate called without an active exception
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/THCTensorCopy.cu line=204 error=59 : device-side assert triggered

[...]

The assertion error also occurs at line 362:

Traceback (most recent call last):
  File "main.py", line 480, in <module>
    main()
  File "main.py", line 242, in main
    val_prec1, val_prec5 = validate(val_loader, model, criterion)
  File "main.py", line 362, in validate
    losses.update(loss.data[0], input.size(0))
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/generic/THCStorage.c:32
THCudaCheckWarn FAIL file=/opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/THCStream.cpp line=50 error=29 : driver shutting down
THCudaCheckWarn FAIL file=/opt/conda/conda-bld/pytorch_1503970438496/work/torch/lib/THC/THCStream.cpp line=50 error=29 : driver shutting down

The training and testing images are 64x128, and I also tried resizing only the training images to 256x256.

It seems to be caused by an inconsistency in the number of classes that I am still trying to pin down. During evaluation it can happen that some test classes have no samples, which I noticed can be problematic for your network if the class directories do not match properly. A solution that worked for me on another dataset is to create the same set of class directories, with exactly the same names, in both the training (train) and testing (val) partitions, leaving empty the class directories that have no test samples. On this dataset, however, I still get the error above, which is a bit confusing to me.
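For reference, this is roughly the workaround I used on the other dataset (the paths below are hypothetical and only illustrate the idea of mirroring the class directories):

    import os

    # Hypothetical ImageFolder-style layout: one sub-directory per class.
    train_dir = 'person-reid/market1501/train'
    val_dir = 'person-reid/market1501/val'

    # Union of the class names seen in either partition.
    train_classes = {d for d in os.listdir(train_dir)
                     if os.path.isdir(os.path.join(train_dir, d))}
    val_classes = {d for d in os.listdir(val_dir)
                   if os.path.isdir(os.path.join(val_dir, d))}
    all_classes = train_classes | val_classes

    # Create any missing class directory (left empty) so that both partitions
    # expose exactly the same class list, and hence the same label indices.
    for split_dir in (train_dir, val_dir):
        for cls in all_classes:
            os.makedirs(os.path.join(split_dir, cls), exist_ok=True)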

Another question I have is whether your network can be used to classify images of classes that exist in the test partition but not in the train partition, for instance a dataset where half of the classes are used for training and the other half for testing.

I would appreciate any comments if you have a clue about what might be causing this error, as well as any thoughts on the question about the classes.

Thanks a lot

ShichenLiu commented 6 years ago

Hi,

Yes, you are right. The code (largely borrowed from the PyTorch official examples) decides the class order by walking through the dataset directories, so the train/test folders should be exactly aligned. As for your problem, I would suggest carefully inspecting the "target" with pdb to check that the class labels are non-negative integers and correct, i.e. smaller than the total number of classes.
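For example, a rough sanity check like the following should quickly reveal an out-of-range label (num_classes and the loader variable names here are placeholders for whatever main.py actually uses). Running with CUDA_LAUNCH_BLOCKING=1 also makes the device-side assert point to the real failing call instead of a later, unrelated one.

    # Illustrative sanity check inside the training loop:
    for i, (input, target) in enumerate(train_loader):
        # Labels fed to NLL/cross-entropy loss must lie in [0, num_classes).
        assert target.min() >= 0 and target.max() < num_classes, \
            'bad labels: min={}, max={}, num_classes={}'.format(
                target.min(), target.max(), num_classes)
        # ... usual forward / backward pass ...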

As for your other question, I would suggest looking into "zero-shot classification". I think it is a different setting, where you would need to combine other approaches with our architecture.