xichenpan / ARLDM

Official Pytorch Implementation of Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models
https://arxiv.org/abs/2211.10950
MIT License
182 stars, 28 forks

ckpt is not saved after training? #6

Closed: kriskrisliu closed this issue 1 year ago

kriskrisliu commented 1 year ago

I ran training with the config file below. Everything looked fine during training. However, when training ended, I found no checkpoint file in ckpt_dir. Did I miss anything?

# device
mode: train  # train sample
gpu_ids: [ 0,1,2,3 ]  # gpu ids
batch_size: 1  # batch size each item denotes one story
num_workers: 16  # number of workers
num_cpu_cores: -1  # number of cpu cores
seed: 0  # random seed
ckpt_dir: ./result/flintstones # checkpoint directory
run_name: 5epoch_visualization # name for this run

# task
dataset: flintstones  # pororo flintstones vistsis vistdii
task: visualization  # continuation visualization

# train
init_lr: 1e-5  # initial learning rate
warmup_epochs: 1  # warmup epochs
max_epochs: 5  # max epochs
train_model_file:  # model file for resume, none for train from scratch
freeze_clip: False #True  # whether to freeze clip
freeze_blip: False  # whether to freeze blip
freeze_resnet: False  # whether to freeze resnet

# sample
# test_model_file:  # model file for test
# calculate_fid: True  # whether to calculate FID scores
# scheduler: ddim  # ddim pndm
# guidance_scale: 6  # guidance scale
# num_inference_steps: 250  # number of inference steps
# sample_output_dir: /path/to/save_samples # output directory

# pororo:
#   hdf5_file: /path/to/pororo.h5
#   max_length: 85
#   new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
#   clip_embedding_tokens: 49416
#   blip_embedding_tokens: 30530

flintstones:
  hdf5_file: /root/autodl-tmp/dataset/flintstones.h5
  max_length: 91
  new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
  clip_embedding_tokens: 49412
  blip_embedding_tokens: 30525

# vistsis:
#   hdf5_file: /path/to/vist.h5
#   max_length: 100
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

# vistdii:
#   hdf5_file: /path/to/vist.h5
#   max_length: 65
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

hydra:
  run:
    dir: .
  output_subdir: null
hydra/job_logging: disabled
hydra/hydra_logging: disabled

xichenpan commented 1 year ago

Hi, my sincere apologies for this bug. Checkpointing was disabled while we were testing our migrated code (because the model is very large, saving takes time and storage). You can modify the code as shown here: https://github.com/Flash-321/ARLDM/blob/f44277744517041ac9a955794c4f3a5f73d59eb9/main.py#L403 With that change, it should save only the last checkpoint.

You can also customize checkpoint saving behavior according to: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
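For readers who want to customize this, here is a minimal sketch of a "save only the last checkpoint" setup with PyTorch Lightning's `ModelCheckpoint`. It is not the repository's actual code: the helper names `checkpoint_kwargs` and `build_trainer` are hypothetical, and joining `ckpt_dir` with `run_name` to form the output directory is an assumption based on the config fields in this issue. The `ModelCheckpoint` arguments (`dirpath`, `save_top_k`, `save_last`) and the `Trainer` arguments used are real Lightning parameters.

```python
import os


def checkpoint_kwargs(ckpt_dir, run_name):
    """Arguments for ModelCheckpoint that keep only the final checkpoint.

    Assumption: checkpoints go to <ckpt_dir>/<run_name>.
    """
    return {
        "dirpath": os.path.join(ckpt_dir, run_name),
        "save_top_k": 0,    # do not keep any "best k" checkpoints
        "save_last": True,  # write last.ckpt
    }


def build_trainer(ckpt_dir, run_name, max_epochs):
    """Hypothetical helper: build a Trainer with checkpointing turned on."""
    # Imported lazily so the kwargs above can be inspected without Lightning.
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    callback = ModelCheckpoint(**checkpoint_kwargs(ckpt_dir, run_name))
    return pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[callback],
        enable_checkpointing=True,  # must not be False, or nothing is saved
    )


# Values taken from the config posted in this issue.
kwargs = checkpoint_kwargs("./result/flintstones", "5epoch_visualization")
print(kwargs["dirpath"])
```

To keep more than the last checkpoint, `save_top_k` and `monitor` can be set instead, at the cost of extra disk space for a model this large.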

Thanks a lot for pointing out this bug!

kriskrisliu commented 1 year ago

Fortunately, I only ran a 5-epoch training. Thanks for the reply~ BTW, I trained the model on 4x A100 and found that it takes ~6 hours per epoch. Does that sound OK? How long does a full training run take (say, 50 epochs on 8x A100)?

xichenpan commented 1 year ago

That seems a little slow. I trained the model on a single node with 8 A100s, and training took 2-3 days to finish. I suggest checking whether your 4 A100 GPUs are on the same node. Also check whether the backward pass takes much longer than the forward pass (usually the two should take about the same time); you can easily set up a profiler to measure the time cost of each component (see https://pytorch-lightning.readthedocs.io/en/1.6.4/advanced/profiler.html). I also noticed that num_workers is set to 16 in your config. For me, that is much slower than 4, so you may want to test this setting on your machine.
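A minimal sketch of the profiler suggestion, assuming PyTorch Lightning 1.6 or later. The helper names `profiling_trainer_kwargs` and `build_profiling_trainer` are hypothetical and not part of ARLDM; `profiler="simple"`, `max_epochs`, `limit_train_batches`, `accelerator`, `devices`, and `strategy` are real `Trainer` arguments. The idea is to run a short, truncated epoch so the timing report is printed after a few minutes instead of after a 6-hour epoch.

```python
def profiling_trainer_kwargs(num_batches=50):
    """Trainer arguments for a short profiling run.

    The default of 50 batches is an arbitrary illustrative choice.
    """
    return {
        "profiler": "simple",                # built-in per-hook timing report
        "max_epochs": 1,                     # one short pass is enough
        "limit_train_batches": num_batches,  # stop the epoch early
    }


def build_profiling_trainer(gpu_ids, num_batches=50):
    """Hypothetical helper: multi-GPU Trainer that reports time per component."""
    # Imported lazily so the kwargs above can be inspected without Lightning.
    import pytorch_lightning as pl

    return pl.Trainer(
        accelerator="gpu",
        devices=gpu_ids,   # e.g. [0, 1, 2, 3] as in the posted config
        strategy="ddp",
        **profiling_trainer_kwargs(num_batches),
    )


kwargs = profiling_trainer_kwargs()
print(kwargs)
```

Running the same short job once with `num_workers: 16` and once with `num_workers: 4` in the config, and comparing the data-loading entries in the two reports, is one way to test the num_workers suggestion.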