adymaharana / storydalle

MIT License

OOM issues #5

Open andlyu opened 2 years ago

andlyu commented 2 years ago

Hey, thanks for the interesting work. I was trying to run train_story.sh but ran into memory issues on an NVIDIA V100. Could you share the configuration you ran it on, and are there ways to reduce the GPU memory requirements?

adymaharana commented 2 years ago

Hi, thanks for your interest in our work! We trained the model on a single A6000, which has 48 GB of memory. I am sorry that this model has such voracious memory requirements. I am also trying to reduce the memory load for inference so that it fits on a smaller GPU for public demo purposes. My suggestion is to keep the VQGAN on the CPU and pre-extract the image embeddings so that it does not become a training-speed bottleneck. I will let you know if I find better ways to run it on smaller GPUs.
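A minimal sketch of the pre-extraction idea, assuming the VQGAN encoder can be wrapped as a plain function that maps an image to a list of integer codes. `precompute_codes`, `load_image` and `encode` are hypothetical names, not part of the storydalle codebase, and JSON is used only to keep the example dependency-free. The encoder runs once per image; every later epoch reads the cached codes.

```python
import json
from pathlib import Path
from typing import Callable, Dict, Iterable, List


def precompute_codes(
    image_ids: Iterable[str],
    load_image: Callable[[str], object],
    encode: Callable[[object], List[int]],
    cache_path: Path,
) -> Dict[str, List[int]]:
    """Encode every image once and cache the resulting codes on disk.

    `encode` is a hypothetical stand-in for a CPU-resident VQGAN encoder.
    """
    if cache_path.exists():
        # Cache hit: the encoder is not called at all.
        return json.loads(cache_path.read_text())
    # Cache miss: run the (slow) encoder once per image and persist the result.
    codes = {image_id: list(encode(load_image(image_id))) for image_id in image_ids}
    cache_path.write_text(json.dumps(codes))
    return codes
```

The design choice is simply to move the encoder cost out of the training loop: the GPU only ever sees the cached integer codes, so the VQGAN does not need to occupy GPU memory during training.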

julkaztwittera commented 1 year ago

Hi, I am also trying to run Story-DALL-E on an NVIDIA V100, which has 32 GB. Even inference does not work; I get the following error:

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasGemmStridedBatchedExFix( handle, opa, opb, m, n, k, (void*)(&falpha), a, CUDA_R_16F, lda, stridea, b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), c, CUDA_R_16F, ldc, stridec, num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)

I am extremely fond of your work and would like to build on it, so please let me know if you manage to solve this issue.

adymaharana commented 1 year ago

Hi,

I am glad to hear that this work is of interest to you! I have finally figured out how to run inference in less than 32 GB. The key is to perform full-precision inference in Stage 1 (the VQGAN) and mixed-precision inference in Stage 2 (the autoregressive decoder). The codebase has been updated with these changes, and you should now be able to run inference on an NVIDIA V100. Feel free to report in this thread if you face any further problems. Thanks!
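For scale, a back-of-the-envelope estimate of weight memory alone at different precisions. It assumes the parameter count reported in the training log later in this thread (1,396,275,075) and ignores activations, the attention cache and CUDA context overhead, so real usage is higher. `weight_memory_gib` is a hypothetical helper.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2}
GIB = 1024 ** 3


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Memory for the weights only, in GiB (no activations, no cache)."""
    return num_params * BYTES_PER_PARAM[dtype] / GIB


total = 1_396_275_075  # "Total parameters" from the training log in this thread
for dtype in ("fp32", "fp16"):
    print(f"{dtype}: {weight_memory_gib(total, dtype):.2f} GiB")
```

This prints roughly 5.20 GiB for fp32 and 2.60 GiB for fp16. Since even the fp32 weights are far below 32 GB, the savings from running Stage 2 in mixed precision presumably come mostly from smaller activations rather than from the weights themselves.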

KyonP commented 1 year ago

I am trying to train the model on an A100 (80 GB VRAM), and it keeps failing due to OOM.

if [ "$1" = "pororo" ]; then
  echo "Training on Pororo"
  DATA_DIR=../../datasets/pororo_data_512/
  OUTPUT_ROOT=./save/pororo
  SENT_EMBED=512
  STORY_LEN=4
  LR=1e-4
  TRAIN_BS=1
  GRAD_ACC=4
elif [ "$1" = "flintstones" ]; then
  echo "Training on Flintstones"
  DATA_DIR=../data/flintstones
  OUTPUT_ROOT=./out/flintstones
  SENT_EMBED=512
  STORY_LEN=4
  LR=1e-5
  TRAIN_BS=1
  GRAD_ACC=4
elif [ "$1" = "didemo" ]; then
  echo "Training on DiDeMo"
  DATA_DIR=../data/didemo
  OUTPUT_ROOT=./out/didemo
  SENT_EMBED=512
  STORY_LEN=2
  TRAIN_BS=1
  GRAD_ACC=8
fi

LOG_DIR=../runs/

python ./train_t2i.py \
--prefix_model_name_or_path './1.3B/' \
--tuning_mode story \
--dataset_name $1 \
--preseqlen 32 \
--condition \
--story_len $STORY_LEN \
--sent_embed $SENT_EMBED \
--prefix_dropout 0.2 \
--data_dir $DATA_DIR \
--dataloader_num_workers 4 \
--output_dir $OUTPUT_ROOT \
--log_dir $LOG_DIR \
--do_train --do_eval \
--per_gpu_train_batch_size $TRAIN_BS \
--per_gpu_eval_batch_size 1 \
--num_train_epochs 50 \
--gradient_accumulation_steps $GRAD_ACC \
--learning_rate $LR \
--logging_steps 50 \
--eval_steps 500 \
--generate_steps 1000 \
--overwrite_output_dir
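For reference, the number of stories contributing to each optimizer step under this script can be computed as below. This assumes the common Hugging Face-style convention that each forward pass uses `per_gpu_train_batch_size × max(1, n_gpu)` stories and that gradients are accumulated over `gradient_accumulation_steps`; whether `train_t2i.py` follows exactly this convention is an assumption. `effective_batch` is a hypothetical helper.

```python
# (TRAIN_BS, GRAD_ACC) per dataset, copied from the shell script above.
CONFIGS = {
    "pororo": (1, 4),
    "flintstones": (1, 4),
    "didemo": (1, 8),
}


def effective_batch(dataset: str, n_gpu: int = 1) -> int:
    """Stories per optimizer step: per-GPU batch x max(1, n_gpu) x accumulation steps."""
    train_bs, grad_acc = CONFIGS[dataset]
    return train_bs * max(1, n_gpu) * grad_acc


print(effective_batch("pororo"))           # one visible GPU
print(effective_batch("pororo", n_gpu=8))  # eight visible GPUs
```

Only the per-forward-pass batch, not the accumulated total, drives peak activation memory: gradient accumulation trades extra steps for a smaller per-pass footprint.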

I haven't changed your code, but it gives me the following OOM error:

Training on Pororo
Global seed set to 42
Initializing the Conditional Dalle model
Setting up Cross-attention Layers
Total parameters is 1396275075
[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41]
Loading models from checkpoint ./1.3B/
./1.3B/stage1_last.ckpt successfully restored..
./1.3B/stage2_last.ckpt succesfully restored..
path :  ./1.3B/tokenizer/bpe-16k-vocab.json
./1.3B/tokenizer successfully restored..
Training dataset size: %s 10191
Validation dataset size: %s 2334
03/25/2023 12:53:04 - WARNING - __main__ -   Process rank: -1, device: cuda, n_gpu: 8, distributed training: False, 16-bits training: False
Maximum optimizer steps : 63693
Moving model to CUDA
Cross-attention layers are in cuda: True
Training:   0%|          | 0/1273 [00:00<?, ?it/s]
Train Epoch: 0 [0/10191 (0%)]   Loss: 7.173273
torch.return_types.topk(
values=tensor([0.03, 0.03, 0.02, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01],
       device='cuda:0'),
indices=tensor([ 9834,  7738,   852, 14283,  9570,  7927,  5370,  1363,  2102, 14059],
       device='cuda:0')) tensor(5238, device='cuda:0')
Training:   0%|▏         | 1/1273 [00:15<5:22:19, 15.20s/it]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/my/storydalle/story-dalle/./train_t2i.py:561 in <module>                               │
│                                                                                                  │
│   558 │                                                                                          │
│   559 │   args = parser.parse_args()                                                             │
│   560 │                                                                                          │
│ ❱ 561 │   main(args)                                                                             │
│                                                                                                  │
│ /home/my/storydalle/story-dalle/./train_t2i.py:403 in main                                   │
│                                                                                                  │
│   400 │   │   │   │   sent_embeds = batch[3]                                                     │
│   401 │   │   │   │   src_images = src_images.to(device)                                         │
│   402 │   │   │   │   sent_embeds = sent_embeds.to(device)                                       │
│ ❱ 403 │   │   │   │   logits_img, logits_txt, codes = model(images, src_images, texts, sent_em   │
│   404 │   │   │   else:                                                                          │
│   405 │   │   │   │   logits_img, logits_txt, codes = model(images, texts)                       │
│   406                                                                                            │
│                                                                                                  │
│ /root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /home/my/storydalle/story-dalle/dalle/models/__init__.py:1041 in forward                     │
│                                                                                                  │
│   1038 │   │   # pos_enc_code = pos_enc_code.unsqueeze(-1)                                       │
│   1039 │   │   # print(images.shape, codes.shape, texts.shape)                                   │
│   1040 │   │   if self.config.story.condition:                                                   │
│ ❱ 1041 │   │   │   logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,       │
│   1042 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     pos_enc_code, pos_  │
│   1043 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     self.cross_attenti  │
│   1044 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │     prompt=prompt, pos  │
│                                                                                                  │
│ /home/my/storydalle/story-dalle/dalle/models/stage2/transformer.py:212 in                    │
│ forward_with_context                                                                             │
│                                                                                                  │
│   209 │   │   │   if i in cross_attention_idxs:                                                  │
│   210 │   │   │   │   x, _ = block.sample_with_context(x, src_images, mask, cross_attention_la   │
│   211 │   │   │   else:                                                                          │
│ ❱ 212 │   │   │   │   x, _ = block.sample(x, layer_past=None if past is None else past[i])       │
│   213 │   │                                                                                      │
│   214 │   │   x = self.ln_f(x)                                                                   │
│   215                                                                                            │
│                                                                                                  │
│ /home/my/storydalle/story-dalle/dalle/models/stage2/layers.py:179 in sample                  │
│                                                                                                  │
│   176 │   def sample(self, x, layer_past=None):                                                  │
│   177 │   │   attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)      │
│   178 │   │   x = x + attn                                                                       │
│ ❱ 179 │   │   x = x + self.mlp(self.ln2(x))                                                      │
│   180 │   │   return x, present                                                                  │
│   181 │                                                                                          │
│   182 │   def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past   │
│                                                                                                  │
│ /root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/container.py:139 in        │
│ forward                                                                                          │
│                                                                                                  │
│   136 │   # with Any as TorchScript expects a more precise type                                  │
│   137 │   def forward(self, input):                                                              │
│   138 │   │   for module in self:                                                                │
│ ❱ 139 │   │   │   input = module(input)                                                          │
│   140 │   │   return input                                                                       │
│   141 │                                                                                          │
│   142 │   def append(self, module: Module) -> 'Sequential':                                      │
│                                                                                                  │
│ /root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130 in          │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1127 │   │   # this function, and just call forward.                                           │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                                         │
│   1131 │   │   # Do not call functions when jit is used                                          │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:                                │
│                                                                                                  │
│ /root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/linear.py:114 in forward   │
│                                                                                                  │
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │
│   112 │                                                                                          │
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │
│   115 │                                                                                          │
│   116 │   def extra_repr(self) -> str:                                                           │
│   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA out of memory. Tried to allocate 242.00 MiB (GPU 0; 79.21 GiB total capacity; 75.25 GiB already allocated; 201.62 MiB free; 77.42 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
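The allocator hint in this message only applies when reserved memory is much larger than allocated memory. A small stdlib parser makes that check concrete; the regular expressions assume the message format shown above, and `parse_cuda_oom` is a hypothetical helper.

```python
import re

OOM_MESSAGE = (
    "CUDA out of memory. Tried to allocate 242.00 MiB (GPU 0; 79.21 GiB total capacity; "
    "75.25 GiB already allocated; 201.62 MiB free; 77.42 GiB reserved in total by PyTorch)"
)

_FIELD = r"([\d.]+) (MiB|GiB)"
_UNIT_TO_GIB = {"MiB": 1 / 1024, "GiB": 1.0}


def _gib(value: str, unit: str) -> float:
    """Convert a number with a MiB/GiB unit to GiB."""
    return float(value) * _UNIT_TO_GIB[unit]


def parse_cuda_oom(message: str) -> dict:
    """Extract the memory figures (in GiB) from a PyTorch CUDA OOM message."""
    patterns = {
        "requested": r"Tried to allocate " + _FIELD,
        "total": _FIELD + r" total capacity",
        "allocated": _FIELD + r" already allocated",
        "free": _FIELD + r" free",
        "reserved": _FIELD + r" reserved in total",
    }
    stats = {}
    for name, pattern in patterns.items():
        match = re.search(pattern, message)
        if match:
            stats[name] = _gib(*match.groups())
    return stats


stats = parse_cuda_oom(OOM_MESSAGE)
slack = stats["reserved"] - stats["allocated"]
print(f"reserved - allocated = {slack:.2f} GiB")
```

Here the gap is about 2.17 GiB out of roughly 77 GiB reserved, so fragmentation is probably a minor factor and setting `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128` (or a similar value) alone is unlikely to be enough; the allocated memory itself is close to the card's capacity.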

I am not sure why this is happening and would appreciate any suggestions.
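Two quick checks against the numbers in the log, as a hedged sketch. First, with the logged 1,396,275,075 parameters and an fp32 Adam-style optimizer (weights, gradients and two moment buffers, 16 bytes per parameter), static training state is roughly 21 GiB before any activations; this is an upper bound that assumes every parameter is trainable, which the log does not confirm. Second, the progress bar's 1273 iterations per epoch equals 10191 // 8, which is consistent with the per-pass batch being multiplied by the `n_gpu: 8` in the warning line. If the model is not wrapped in `torch.nn.DataParallel`, all 8 stories of each batch would then be processed on GPU 0; running with `CUDA_VISIBLE_DEVICES=0` is one way to test that hypothesis. Both helpers below are hypothetical.

```python
GIB = 1024 ** 3


def adam_state_gib(trainable_params: int, bytes_per_param: int = 16) -> float:
    """fp32 weights + grads + Adam exp_avg + exp_avg_sq = 4 x 4 bytes per parameter."""
    return trainable_params * bytes_per_param / GIB


def iterations_per_epoch(dataset_size: int, per_gpu_bs: int, n_gpu: int,
                         drop_last: bool = True) -> int:
    """Dataloader length when each batch holds per_gpu_bs * max(1, n_gpu) samples."""
    batch = per_gpu_bs * max(1, n_gpu)
    # drop_last=True discards the final partial batch; otherwise round up.
    return dataset_size // batch if drop_last else -(-dataset_size // batch)


print(f"{adam_state_gib(1_396_275_075):.1f} GiB")          # upper bound: all params trainable
print(iterations_per_epoch(10191, per_gpu_bs=1, n_gpu=8))  # matches the 1273 in the progress bar
print(iterations_per_epoch(10191, per_gpu_bs=1, n_gpu=1))  # what a single visible GPU would give
```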