lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License
5.56k stars, 644 forks

generate + fp16 (deepspeed) #261

Open rom1504 opened 3 years ago

rom1504 commented 3 years ago

This issue tracks the fact that generation (`generate_images`) does not work with fp16 (DeepSpeed). The problem was introduced together with the fp16 feature in https://github.com/lucidrains/DALLE-pytorch/pull/157:

/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [120,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [121,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [122,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [123,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/THC/THCTensorIndex.cu:218: indexSelectSmallIndex: block: [0,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "train_dalle.py", line 469, in <module>
    image = dalle.generate_images(text[:1], filter_thres=0.9).float()  # topk sampling at 0.9
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "DALLE-pytorchdalle_pytorch/dalle_pytorch.py", line 42, in inner
    out = fn(model, *args, **kwargs)
  File "DALLE-pytorchdalle_pytorch/dalle_pytorch.py", line 425, in generate_images
    logits = self(text, image, mask = mask)[:, -1, :]
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorchdalle_pytorch/dalle_pytorch.py", line 497, in forward
    out = self.transformer(tokens)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorchdalle_pytorch/transformer.py", line 132, in forward
    return self.layers(x, **kwargs)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorchdalle_pytorch/transformer.py", line 44, in forward
    return self.fn(x, **kwargs) * self.scale
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorchdalle_pytorch/transformer.py", line 53, in forward
    return self.fn(self.norm(x), **kwargs)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorchdalle_pytorch/transformer.py", line 71, in forward
    return self.net(x)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/modules/linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "DALLE-pytorch.env/lib64/python3.6/site-packages/torch/nn/functional.py", line 1676, in linear
    output = input.matmul(weight.t())
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

That's the exact error (the issue is also mentioned in #256).

I'm looking into it.

Any ideas on the topic are welcome.
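A hedged debugging sketch, not part of the original thread: the repeated `srcIndex < srcSelectDimSize` assertions come from the CUDA `index_select` kernel, which embedding lookups typically go through, and they fire when an index is outside the size of the dimension being indexed. Because CUDA errors are reported asynchronously, the `CUBLAS_STATUS_EXECUTION_FAILED` at the end of the traceback may only be a downstream symptom; running with `CUDA_LAUNCH_BLOCKING=1` usually makes the traceback point at the kernel that actually failed. The helpers `find_out_of_range_ids` and `checked_lookup` below are hypothetical names; they check token ids against an embedding's `num_embeddings` so an out-of-range id produces a readable Python error. Whether an out-of-range id is the real cause here is an assumption the thread does not confirm.

```python
import torch
import torch.nn as nn


def find_out_of_range_ids(token_ids: torch.Tensor, embedding: nn.Embedding) -> torch.Tensor:
    # Hypothetical helper: returns the coordinates of every id outside
    # [0, num_embeddings), i.e. the ids that would trip the
    # `srcIndex < srcSelectDimSize` assertion in a CUDA lookup.
    bad = (token_ids < 0) | (token_ids >= embedding.num_embeddings)
    return bad.nonzero(as_tuple=False)


def checked_lookup(token_ids: torch.Tensor, embedding: nn.Embedding) -> torch.Tensor:
    # Hypothetical wrapper: validate ids first so the failure is a readable
    # Python exception instead of an asynchronous device-side assert.
    bad = find_out_of_range_ids(token_ids, embedding)
    if bad.numel() > 0:
        raise ValueError(
            f"{bad.shape[0]} token id(s) outside [0, {embedding.num_embeddings}) "
            f"at positions {bad.tolist()}"
        )
    return embedding(token_ids)


if __name__ == "__main__":
    emb = nn.Embedding(10, 4)  # toy vocabulary of 10 tokens, assumed for illustration
    print(checked_lookup(torch.tensor([[1, 2, 3]]), emb).shape)  # torch.Size([1, 3, 4])
    print(find_out_of_range_ids(torch.tensor([[1, 2, 12]]), emb).tolist())  # [[0, 2]]
```

Where such a check would go inside `generate_images` is left open, since the thread does not establish which lookup receives the bad index.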

janEbert commented 3 years ago

Another way to get the same error is to wrap your model initialization in `deepspeed.zero.Init`. This is also why #222 is still WIP: it runs into the same issue. See this comment, which explains when this happens during generation.

I think this is either a DeepSpeed issue or a sign that we are not using the API the way DeepSpeed intends. I asked about API usage for our case in this issue but sadly haven't gotten a response. I also tried keeping the VAE completely separate from the DALLE model, passing it in as a parameter instead, but that didn't help either.