NVIDIA / pix2pixHD

Synthesizing and manipulating 2048x1024 images with conditional GANs
https://tcwang0509.github.io/pix2pixHD/

CUDA assertion error binary_cross_entropy loss #9

Open blancaag opened 6 years ago

blancaag commented 6 years ago

A CUDA assertion error pops up when setting --no_lsgan. It seems to happen because negative values are fed into nn.BCELoss(). It gets fixed by using nn.BCEWithLogitsLoss() instead.

(...)
/opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THCUNN/BCECriterion.cu:30: Acctype bce_functor<Dtype, Acctype>::operator()(Tuple) [with Tuple = thrust::tuple<float, float, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>, Dtype = float, Acctype = float]: block: [16,0,0], thread: [31,0,0] Assertion `input >= 0. && input <= 1.` failed.
CUDA error after cudaEventDestroy in future dtor: device-side assert triggered
Traceback (most recent call last):
  File "train.py", line 56, in <module>
    Variable(data['image']), Variable(data['feat']), infer=save_fake)
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 66, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/blanca/project/wip/pix2pixHD-master/models/pix2pixHD_model.py", line 158, in forward
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)
  File "/blanca/project/wip/pix2pixHD-master/models/networks.py", line 110, in __call__
    loss += self.loss(pred, target_tensor)
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 372, in forward
    size_average=self.size_average)
  File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 1179, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, size_average)
RuntimeError: cudaEventSynchronize in future::wait: device-side assert triggered
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THC/generic/THCStorage.c line=184 error=59 : device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
  what():  cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1512386481460/work/torch/lib/THC/generic/THCStorage.c:184
Aborted (core dumped)
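
For illustration, a minimal standalone snippet (not from the repo) showing the difference: nn.BCELoss expects probabilities in [0, 1] and trips the device-side assert on raw logits, while nn.BCEWithLogitsLoss applies the sigmoid internally.

import torch
import torch.nn as nn

# Raw discriminator outputs (logits) are not restricted to [0, 1].
logits = torch.tensor([-2.3, 0.7, 1.5])
target = torch.zeros_like(logits)

# nn.BCELoss()(logits, target) would fail here; on CUDA it triggers the
# `input >= 0. && input <= 1.` assert shown in the traceback above.
loss = nn.BCEWithLogitsLoss()(logits, target)  # sigmoid is applied internally
print(loss)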
Tord-Zhang commented 6 years ago

@blancaag I have met the same problem. How did you fix it?

blancaag commented 6 years ago

@mangdian I mentioned it above. It gets fixed by using nn.BCEWithLogitsLoss() instead of nn.BCELoss() in networks.py line 82: it applies a sigmoid internally, so the predictions are mapped into (0, 1) before the loss is computed.
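
For reference, a sketch of that change, assuming GANLoss.__init__ in models/networks.py picks the criterion roughly like this (line numbers may differ across versions):

# models/networks.py, inside GANLoss.__init__ (paraphrased)
if use_lsgan:
    self.loss = nn.MSELoss()
else:
    # was: self.loss = nn.BCELoss()
    self.loss = nn.BCEWithLogitsLoss()  # accepts raw logits; applies sigmoid internally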

aviel08 commented 6 years ago

I think I'm having the same issue, but only when I use my own dataset. I've tried nn.BCEWithLogitsLoss() with no luck. It must be related to my data, but I can't figure out what I'm missing.

RuntimeError: CUDNN_STATUS_INTERNAL_ERROR
/opt/conda/conda-bld/pytorch_1525812548180/work/aten/src/THC/THCTensorScatterGather.cu:176: void THCudaTensor_scatterFillKernel(TensorInfo<Real, IndexType>, TensorInfo<long, IndexType>, Real, int, IndexType) [with IndexType = unsigned int, Real = float, Dims = -1]: block: [88,0,0], thread: [346,0,0] Assertion `indexValue >= 0 && indexValue < tensor.sizes[dim]` failed.
(the same assertion repeats for threads [347,0,0] through [351,0,0])
THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1525812548180/work/aten/src/THC/generic/THCStorage.c line=184 error=59 : device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
  what():  cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1525812548180/work/aten/src/THC/generic/THCStorage.c:184
Aborted (core dumped)

blancaag commented 6 years ago

@aviel08 - I think it's a different error, and not in the BCELoss: the assertion is `indexValue >= 0 && indexValue < tensor.sizes[dim]`. I'd suggest starting by printing the shapes of the input tensors after this line: https://github.com/NVIDIA/pix2pixHD/blob/20687df85d30e6fff5aafb29b7981923da9fd02f/train.py#L51
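
Something along these lines (a hypothetical debug snippet, assuming the usual data dict keys) would show whether the label map is the culprit; that scatter assert typically fires when a label id is negative or >= --label_nc, which breaks the one-hot encoding of the label map:

print(data['label'].shape, data['inst'].shape, data['image'].shape)
print('label range:', data['label'].min(), data['label'].max())  # should stay within [0, label_nc)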

ZhangXiaoying0116 commented 6 years ago

@aviel08 I met the same problem. How did you solve it?

hfarazi commented 5 years ago

You can clamp the output of your sigmoid layer to [0, 1], e.g. with x.clamp(0, 1) (torch.clamp(x, 0, 1)).
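
A minimal sketch of that suggestion (logits and target are placeholders), keeping nn.BCELoss but making sure it only ever sees values inside [0, 1]:

import torch
import torch.nn as nn

logits = torch.randn(4)
target = torch.ones(4)
prob = torch.sigmoid(logits).clamp(0, 1)  # guard against numerical excursions outside [0, 1]
loss = nn.BCELoss()(prob, target)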

relh commented 5 years ago

I also had to add:

x = torch.where(torch.isnan(x), torch.zeros_like(x), x)
x = torch.where(torch.isinf(x), torch.zeros_like(x), x)
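
Wrapped up, a small hypothetical helper that applies both replacements before the loss:

import torch

def sanitize(x):
    # Replace NaN and Inf entries with zeros so the loss never sees invalid values.
    x = torch.where(torch.isnan(x), torch.zeros_like(x), x)
    x = torch.where(torch.isinf(x), torch.zeros_like(x), x)
    return x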

tongpinmo commented 5 years ago

I applied nn.BCEWithLogitsLoss() instead of nn.BCELoss(), and that solved it.

izuna385 commented 4 years ago

I found that @relh's solution is effective.

>>> torch.nn.functional.sigmoid(torch.tensor(float('nan')))
tensor(nan)

x = torch.where(torch.isnan(x), torch.zeros_like(x), x) prevents this error. Thanks a lot!