xichenpan / ARLDM

Official PyTorch Implementation of Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models
https://arxiv.org/abs/2211.10950
MIT License

Training gets stuck at the beginning #4

Open: kriskrisliu opened this issue 1 year ago

kriskrisliu commented 1 year ago

I'm trying to train this model on the FlintstonesSV dataset. I'm running the training script on a GPU server with 8x RTX 3080 Ti cards (12 GB of memory each). Can this server train the model? What is the maximum memory usage during training?

The training process seems to get stuck at "trainer.fit(model, dataloader, ckpt_path=args.train_model_file)". Here is the log:

Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
Global seed set to 0
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
[2022-12-26 18:54:48,402][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 1
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
[2022-12-26 18:54:51,472][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 2
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
[2022-12-26 18:54:55,093][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 3
[2022-12-26 18:54:55,097][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2022-12-26 18:54:55,098][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

[2022-12-26 18:54:55,103][torch.distributed.distributed_c10d][INFO] - Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[2022-12-26 18:54:55,106][torch.distributed.distributed_c10d][INFO] - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[2022-12-26 18:54:55,106][torch.distributed.distributed_c10d][INFO] - Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.

There has been no further output after waiting for 30 minutes...

The config.yaml is:

# device
mode: train  # train sample
gpu_ids: [ 0, 1, 2, 3 ]  # gpu ids
batch_size: 1  # batch size each item denotes one story
num_workers: 4  # number of workers
num_cpu_cores: -1  # number of cpu cores
seed: 0  # random seed
ckpt_dir: results/ # checkpoint directory
run_name: first_try # name for this run

# task
dataset: flintstones  # pororo flintstones vistsis vistdii
task: visualization  # continuation visualization

# train
init_lr: 1e-5  # initial learning rate
warmup_epochs: 1  # warmup epochs
max_epochs: 50  # max epochs
train_model_file:  # model file for resume, none for train from scratch
freeze_clip: False  # whether to freeze clip
freeze_blip: False  # whether to freeze blip
freeze_resnet: False  # whether to freeze resnet

# # sample
# test_model_file:  # model file for test
# calculate_fid: True  # whether to calculate FID scores
# scheduler: ddim  # ddim pndm
# guidance_scale: 6  # guidance scale
# num_inference_steps: 250  # number of inference steps
# sample_output_dir: /path/to/save_samples # output directory

# pororo:
#   hdf5_file: /path/to/pororo.h5
#   max_length: 85
#   new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
#   clip_embedding_tokens: 49416
#   blip_embedding_tokens: 30530

flintstones:
  hdf5_file: Downloads/save_hdf5_files/flintstones.hdf5
  max_length: 91
  new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
  clip_embedding_tokens: 49412
  blip_embedding_tokens: 30525

# vistsis:
#   hdf5_file: /path/to/vist.h5
#   max_length: 100
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

# vistdii:
#   hdf5_file: /path/to/vist.h5
#   max_length: 65
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

hydra:
  run:
    dir: .
  output_subdir: null
hydra/job_logging: disabled
hydra/hydra_logging: disabled
xichenpan commented 1 year ago

It seems to be a PyTorch Lightning problem. We trained the model with batch_size=1 on A100 GPUs with 80 GB of VRAM (using around 70+ GB). Gradient checkpointing, AMP, and an 8-bit optimizer can greatly reduce the VRAM requirement. You can also set freeze_clip=True, freeze_blip=True, and freeze_resnet=True to reduce VRAM usage; see https://github.com/huggingface/diffusers/tree/main/examples/dreambooth and https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html. Even so, it is still not possible to run the model on a 3080 Ti; I guess it could be trained on a V100 (32 GB) or even a 3090 (24 GB) after the above modifications.
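A rough way to see why freezing and an 8-bit optimizer help is to count static memory per parameter. This sketch assumes fp32 weights and gradients (4 bytes each) and an Adam-style optimizer with two state tensors per trainable parameter (8 bytes in fp32, roughly 2 bytes in 8-bit form); frozen parameters keep only their weights. It deliberately leaves out activations, which is where AMP and gradient checkpointing mainly save memory, and the parameter counts in the usage lines are hypothetical, not ARLDM's actual sizes.

```python
def static_training_bytes(
    trainable_params: int,
    frozen_params: int = 0,
    weight_bytes: int = 4,
    grad_bytes: int = 4,
    optimizer_state_bytes: int = 8,
) -> int:
    """Rough static memory for weights, gradients, and optimizer state.

    Activations are NOT included. Defaults assume fp32 weights/gradients and
    Adam/AdamW with two fp32 states per trainable parameter; pass
    optimizer_state_bytes=2 to approximate an 8-bit optimizer (ignoring its
    small quantization-statistics overhead).
    """
    per_trainable = weight_bytes + grad_bytes + optimizer_state_bytes
    # Frozen parameters store weights only: no gradient, no optimizer state.
    return trainable_params * per_trainable + frozen_params * weight_bytes


GB = 10**9
# Hypothetical model with 1e9 parameters (illustrative numbers only).
print(static_training_bytes(1_000_000_000) / GB)  # fp32 AdamW, nothing frozen
print(static_training_bytes(1_000_000_000, optimizer_state_bytes=2) / GB)  # 8-bit states
print(static_training_bytes(600_000_000, 400_000_000, optimizer_state_bytes=2) / GB)  # + 40% frozen
```

Under these assumptions, each frozen parameter saves 12 of its 16 bytes with fp32 AdamW, and switching to 8-bit optimizer states saves about 6 bytes per trainable parameter.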

kriskrisliu commented 1 year ago

Nice! Gradient checkpointing, AMP, and an 8-bit optimizer seem to be optional choices; I'll give them a try, or maybe use more powerful GPUs instead... Actually, I'm curious about freezing CLIP, BLIP, and ResNet, which means freezing most of the parameters. Can that produce comparable results?

xichenpan commented 1 year ago

Yes, we tried freezing CLIP, BLIP, and ResNet in our very early experiments and the performance was still acceptable, but we did not run a full experiment with this setting or test FID scores. BTW, for acceptable performance the model does not need to be trained for 50 epochs; 3-5 epochs are enough, which also reduces your compute cost.

bibisbar commented 1 year ago

Hi, I'd like to know approximately how long training would take if I freeze all three and train on 4 A100s. I really appreciate this great open-source project :)

xichenpan commented 1 year ago

@bibisbar Hi, in the unfrozen setting, the forward pass for batch_size=1 on a single A100 GPU takes about 0.5 s, and freezing will not change this much. The backward pass originally takes about 0.5 s too; freezing those parameters will speed it up, and I guess it may cut the backward time by 50% at most. So freezing only slightly shortens training time, but it does reduce memory usage (which I think is more important, as in parameter-efficient tuning).
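The timing figures above can be turned into a back-of-envelope estimate. This sketch takes the quoted ~0.5 s forward and ~0.5 s backward per batch_size=1 step on one A100, treats the freezing speed-up as a fraction of backward time (at most 0.5 per the guess above), and assumes ideal data-parallel scaling across GPUs with no data-loading or synchronization overhead. The dataset size in the usage lines is a hypothetical placeholder, not the FlintstonesSV size.

```python
import math


def step_seconds(forward_s: float = 0.5, backward_s: float = 0.5, backward_saving: float = 0.0) -> float:
    """Seconds per batch_size=1 step on one GPU.

    backward_saving is the fraction of backward time removed by freezing
    (0.0 = unfrozen, 0.5 = the best case guessed above).
    """
    return forward_s + backward_s * (1.0 - backward_saving)


def training_hours(num_stories: int, num_gpus: int, epochs: int, step_s: float) -> float:
    # Ideal DDP scaling: every GPU handles one story per step in parallel.
    steps_per_epoch = math.ceil(num_stories / num_gpus)
    return epochs * steps_per_epoch * step_s / 3600


# Hypothetical dataset of 10,000 stories, 4 GPUs, 5 epochs.
print(training_hours(10_000, 4, 5, step_seconds()))  # unfrozen
print(training_hours(10_000, 4, 5, step_seconds(backward_saving=0.5)))  # frozen, best case
```

Under these assumptions freezing cuts per-step time by at most a quarter (1.0 s to 0.75 s), which matches the point that its main benefit is memory rather than speed.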

skywalker00001 commented 1 year ago

Hi, but how do I enable gradient checkpointing in a PyTorch Lightning model? With Hugging Face models it's as easy as calling model.enable_gradient_checkpointing(), but that doesn't seem to work for the ARLDM model...
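For reference, gradient checkpointing is an ordinary PyTorch feature: torch.utils.checkpoint.checkpoint can wrap submodules inside any nn.Module's forward, including a LightningModule's, and diffusers models (for example UNet2DConditionModel) expose enable_gradient_checkpointing() on the diffusers model object itself rather than on a Lightning wrapper. The sketch below uses a toy stack of blocks (hypothetical, not ARLDM's architecture) to show the pattern and to check that gradients are unchanged. It assumes PyTorch 1.11 or newer for the use_reentrant argument; whether it fits ARLDM's forward pass or its submodule names has not been checked.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedStack(nn.Module):
    """Toy stand-in for a deep sub-network (hypothetical, not ARLDM's UNet)."""

    def __init__(self, dim: int = 64, depth: int = 4, use_checkpointing: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Intermediate activations inside `block` are not kept; they are
                # recomputed during backward, trading extra compute for memory.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
x = torch.randn(8, 64)
plain = CheckpointedStack(use_checkpointing=False)
ckpt = CheckpointedStack(use_checkpointing=True)
ckpt.load_state_dict(plain.state_dict())  # identical weights in both models

plain(x).sum().backward()
ckpt(x).sum().backward()

# Checkpointing changes memory use, not the math: gradients should match.
grads_match = all(
    torch.allclose(p.grad, q.grad, atol=1e-6)
    for p, q in zip(plain.parameters(), ckpt.parameters())
)
print(grads_match)
```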

xichenpan commented 1 year ago

@skywalker00001 Hi, sorry, it seems PyTorch Lightning does not support this setting: https://github.com/Lightning-AI/lightning/issues/49

skywalker00001 commented 1 year ago

Thanks. The other approaches (freezing ResNet, the CLIP embedding, and the BLIP embedding), AMP, and an 8-bit optimizer together successfully reduced VRAM usage to about 40 GB on my A6000 with batch_size = 1.

xichenpan commented 1 year ago

@skywalker00001 Great!

FlamingJay commented 1 year ago

Setting freeze_clip=True, freeze_blip=True, and freeze_resnet=True on a V100 doesn't work; I still get CUDA out of memory.

skywalker00001 commented 1 year ago

Hello, this is Hou Yi. Your email has been received. Wishing you all the best~

Echo411 commented 1 year ago

Hi, I'm a beginner in deep learning and would appreciate it if you could explain how you used AMP and an 8-bit optimizer to reduce VRAM usage. Also, can I run this project on two 24 GB 3090s after these optimizations? Looking forward to hearing from you.
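One generic way to combine these, sketched under assumptions: freeze a submodule with requires_grad_(False), hand the optimizer only the parameters that still require gradients, and use bitsandbytes' AdamW8bit when it is installed (it needs CUDA tensors at step time), falling back to torch.optim.AdamW otherwise. AMP is a Trainer flag in PyTorch Lightning: precision=16 in 1.x, precision="16-mixed" in 2.x. ToyStoryModel and its attribute names are hypothetical stand-ins; in ARLDM the freeze_clip / freeze_blip / freeze_resnet config options are the intended switches, and this is not ARLDM's code.

```python
import torch
from torch import nn

try:
    # Optional dependency providing 8-bit optimizer states (CUDA tensors needed at step time).
    import bitsandbytes as bnb

    AdamW = bnb.optim.AdamW8bit
except ImportError:
    # Fall back to the regular 32-bit AdamW so the sketch still runs.
    AdamW = torch.optim.AdamW


class ToyStoryModel(nn.Module):
    """Hypothetical stand-in: `text_encoder` plays a pretrained encoder, `denoiser` the trained part."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.text_encoder = nn.Linear(dim, dim)
        self.denoiser = nn.Linear(dim, dim)


def freeze(module: nn.Module) -> None:
    """Stop gradients for a submodule so it needs no gradient buffers or optimizer state."""
    module.requires_grad_(False)
    # Puts dropout/batch-norm layers of the frozen part in inference mode
    # (a later .train() call on the parent model undoes this).
    module.eval()


def build_optimizer(model: nn.Module, lr: float = 1e-5) -> torch.optim.Optimizer:
    # Only parameters that still require gradients are handed to the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return AdamW(trainable, lr=lr)


model = ToyStoryModel()
freeze(model.text_encoder)
optimizer = build_optimizer(model)
print(sum(p.numel() for group in optimizer.param_groups for p in group["params"]))
```

In a LightningModule, the build_optimizer logic would typically be returned from the configure_optimizers hook.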