lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)` #284

Closed cyrilzakka closed 1 year ago

cyrilzakka commented 1 year ago

When trying to encode text for conditional diffusion inside the Dataset class's `__getitem__`:

def __getitem__(self, index):
    path = self.paths[index]
    text = self.texts[index]

    # Randomly sample up to self.num_sentences sentences from the text
    if self.num_sentences and self.num_sentences < len(text.split('. ')):
        text = '. '.join(random.sample(text.split('. '), self.num_sentences))

    # encode the sampled text with T5
    texts, masks = t5.t5_encode_text([text], return_attn_mask = True)
    tensor = avi_to_tensor(str(path), self.transform)
    if self.conditional:
        return (self.cast_num_frames_fn(tensor), texts[0], masks[0])
    else:
        return self.cast_num_frames_fn(tensor)

I receive the following error: `RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

Here's the stack trace:

Traceback (most recent call last):
  File "/scratch/users/czakka/imagen-pytorch/engine_train.py", line 77, in <module>
    loss = trainer.train_step(unet_number = ARGS.unet, max_batch_size = ARGS.batch // 2)
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/trainer.py", line 602, in train_step
    loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/trainer.py", line 618, in step_with_dl_iter
    dl_tuple_output = cast_tuple(next(dl_iter))
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/data.py", line 29, in cycle
    for data in dl:
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/utils/data/dataset.py", line 295, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/data.py", line 149, in __getitem__
    texts, masks = t5.t5_encode_text([text], return_attn_mask = True)
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/t5.py", line 113, in t5_encode_text
    encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)
  File "/scratch/users/czakka/imagen-pytorch/imagen-pytorch/t5.py", line 99, in t5_encode_tokenized_text
    output = t5(input_ids = token_ids, attention_mask = attn_mask)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1846, in forward
    encoder_outputs = self.encoder(
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 1040, in forward
    layer_outputs = layer_module(
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 673, in forward
    self_attention_outputs = self.layer[0](
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 579, in forward
    attention_output = self.SelfAttention(
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/transformers/models/t5/modeling_t5.py", line 498, in forward
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/users/czakka/.local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
cxchhh commented 1 year ago

So how did you solve it? Can you share the fix?
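The thread itself does not record a fix, but the stack trace shows the T5 encoder being run inside `Dataset.__getitem__`, i.e. inside DataLoader worker processes, and CUDA/cuBLAS calls made from forked workers are a common source of errors like this one. Below is a minimal, unconfirmed sketch of one way to avoid that: return the raw text from the dataset and batch-encode it on the GPU in the main training process. The class name `VideoTextDataset` and the helpers reused from the snippet above (`avi_to_tensor`, `self.transform`, `self.cast_num_frames_fn`, `self.num_sentences`) are illustrative, not part of the library's documented API.

```python
import random
import torch
from torch.utils.data import Dataset, DataLoader
from imagen_pytorch import t5

class VideoTextDataset(Dataset):
    # fields (paths, texts, transform, num_sentences, cast_num_frames_fn)
    # are assumed to be set up as in the snippet from the issue

    def __getitem__(self, index):
        path = self.paths[index]
        text = self.texts[index]

        # optionally subsample sentences, as in the original snippet
        if self.num_sentences and self.num_sentences < len(text.split('. ')):
            text = '. '.join(random.sample(text.split('. '), self.num_sentences))

        # CPU-only work in the worker: no CUDA calls here
        tensor = avi_to_tensor(str(path), self.transform)
        return self.cast_num_frames_fn(tensor), text  # return the raw string

# in the main training process, encode each batch of texts on the GPU
dl = DataLoader(dataset, batch_size = 4, num_workers = 4)
for videos, raw_texts in dl:
    text_embeds, text_masks = t5.t5_encode_text(list(raw_texts), return_attn_mask = True)
    # ...pass videos, text_embeds and text_masks to the trainer as usual...
```

Precomputing and caching the T5 embeddings offline would achieve the same thing; either way the point of the sketch is simply to keep CUDA work out of the DataLoader workers.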