valeoai / Maskgit-pytorch

MIT License
145 stars, 15 forks

questions about two stage training #9

Closed: xyzhang626 closed this issue 6 months ago

xyzhang626 commented 6 months ago

Hey @llvictorll and team,

Really appreciate you reproducing MaskGIT and open-sourcing it! It's really helpful for the community. I'd like to better understand the training and fine-tuning strategy mentioned in Sec. 2 of the tech report. Does that mean the first training stage is at 256×256 and the second (fine-tuning) stage is at 512×512?

It would be very helpful if you could explain it in more detail.

llvictorll commented 6 months ago

Hello!

Yes, exactly. For the second stage, we initialize the model with the weights of the model trained on ImageNet 256 and "fine-tune" it for an additional 750,000 iterations at the higher resolution (i.e., 512). Because of the higher resolution, we had to reduce the batch size (from 512 to 128), but the rest of the training remains the same.
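For a sense of scale, the second-stage budget can be converted into images seen and approximate epochs. This is a minimal sketch that assumes the batch size of 128 is the global batch and that the data is the ImageNet-1k training split (1,281,167 images); neither detail is stated explicitly in this thread, and the helper names are hypothetical.

```python
# Assumption: ImageNet-1k (ILSVRC-2012) training split size.
IMAGENET_TRAIN_IMAGES = 1_281_167


def images_seen(iterations: int, batch_size: int) -> int:
    """Total number of training samples processed over a run."""
    return iterations * batch_size


def epochs_seen(iterations: int, batch_size: int,
                dataset_size: int = IMAGENET_TRAIN_IMAGES) -> float:
    """Approximate number of passes over the dataset."""
    return images_seen(iterations, batch_size) / dataset_size


# Second stage as described above: 750,000 iterations at (assumed global) batch 128.
stage2_images = images_seen(750_000, 128)
stage2_epochs = epochs_seen(750_000, 128)
print(f"stage 2: {stage2_images:,} images, ~{stage2_epochs:.1f} epochs")
```

Under those assumptions it prints `stage 2: 96,000,000 images, ~74.9 epochs`, i.e. roughly 75 passes over the training set at 512×512.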

Best,

Victor

VoyageWang commented 6 months ago

I am also fine-tuning the model you provided. I would like to know the transformer's training cross-entropy loss in the second stage. I have trained for 90,000 iterations, but the loss seems to hover around 1.3, which is not great. Should I wait for around 750k iterations like you did, and what was your training-set loss at the same point? I appreciate your response and this excellent work!

xyzhang626 commented 6 months ago

Hello!

Yes, exactly. For the second stage, we initialize the model with the weight of the model trained on ImageNet 256 and "fine tune" for an additional 750,000 iterations on the higher resolution (i.e., 512). Because of the higher resolution, we had to reduce the batch size (from 512 to 128), but the rest of the training remains the same.

Best,

Victor

Appreciate it!

YAOYI626 commented 6 months ago

Hey @VoyageWang, did you train on multiple nodes? One thing worth paying attention to is the learning rate. This repo does not automatically scale the learning rate when you scale up the batch size, so you may need to adjust it manually to find the optimal value.
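A common heuristic when the global batch size grows (for example, going from 8 to 32 GPUs with the same per-GPU batch) is to scale the learning rate proportionally (linear scaling), or by the square root of the batch ratio, which is sometimes suggested for Adam-style optimizers. The sketch below is hypothetical (`scale_lr` is not part of this repo) and assumes the configured batch size is per GPU, so the global batch grows with the GPU count; its output is only a starting point for manual tuning.

```python
import math


def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "linear") -> float:
    """Hypothetical helper: rescale a learning rate tuned at `base_batch` for `new_batch`."""
    if base_batch <= 0 or new_batch <= 0:
        raise ValueError("batch sizes must be positive")
    ratio = new_batch / base_batch
    if rule == "linear":
        # Keep lr / batch_size constant.
        return base_lr * ratio
    if rule == "sqrt":
        # Milder scaling: lr grows with the square root of the batch ratio.
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")


# Assumed setup: lr 1e-4 tuned on 8 GPUs x 32 images = global batch 256,
# then moved to 4 nodes x 8 GPUs x 32 images = global batch 1024.
per_gpu_batch = 32
old_global = 8 * per_gpu_batch
new_global = 32 * per_gpu_batch
print(scale_lr(1e-4, old_global, new_global))          # linear rule
print(scale_lr(1e-4, old_global, new_global, "sqrt"))  # square-root rule
```

Here the linear rule gives 4e-4 and the square-root rule 2e-4; which of the two trains better for this model is not established in the thread.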

llvictorll commented 6 months ago

Hello @VoyageWang, sorry for the delay; last week was intense, haha. It is a little difficult to say: it depends on the batch size, the learning rate, and of course the nature of your images. In my experiments, after approximately 100,000 steps with a batch size larger than 256 and a learning rate of 1e-4, you should already get acceptable results. However, 750k steps were necessary to achieve the best FID/IS. If your data differs from ImageNet, I would advise searching for better hyperparameters (batch size, learning rate, model size). Best, Victor.

xyzhang626 commented 6 months ago

Hey @llvictorll, sorry to bother you again. Have you ever used this code for multi-node training? I tried training on 4 nodes with 32 V100s in total, but the job seems as slow as on 8 V100s. It also shows the following error at the end of an epoch:

  File "/mnt/Maskgit-pytorch/Trainer/trainer.py", line 226, in all_gather
    dist.all_gather_object(tensor_list, obj)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2410, in all_gather_object
    object_list[i] = _tensor_to_object(tensor, tensor_size)

I profiled the forward and backward passes. The backward-pass time fluctuates in a zig-zag pattern, as shown in the following figure.

[Figure: backward-pass time per iteration, fluctuating in a zig-zag pattern]

Have you ever run into this? Or do you have any advice on how to debug it?

Any ideas would be very valuable to me. Thanks in advance!

llvictorll commented 6 months ago

Hello @xyzhang626,

Unfortunately, I don't have any experience with multi-node training, so I may not be able to provide much assistance. (Perhaps this tutorial could be helpful? https://pytorch.org/tutorials/intermediate/ddp_series_multinode.html)

The error seems to be related to the function that gathers the loss values from all the other processes (which I only ran within a single node) for printing/logging purposes. This function may need to be adapted for multi-node training.
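If the gather function is adapted for multiple nodes, one pitfall worth checking (an assumption, not something confirmed in this thread) is the GPU index. With 4 nodes × 8 GPUs, the global rank runs from 0 to 31, while the local rank on each node only runs from 0 to 7. The PyTorch documentation for `all_gather_object` notes that, with NCCL, objects are moved to `torch.cuda.current_device()`, so each process must select its own GPU via `torch.cuda.set_device`, and that index has to be the local rank. The stdlib-only sketch below reads the variables that `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `LOCAL_WORLD_SIZE`); `RankInfo` and `read_torchrun_env` are hypothetical names, not part of this repo.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class RankInfo:
    rank: int              # global rank across all nodes (0 .. world_size - 1)
    local_rank: int        # rank within this node; use this as the CUDA device index
    world_size: int        # total number of processes across all nodes
    local_world_size: int  # number of processes (GPUs) on this node

    @property
    def node_index(self) -> int:
        # Assumes every node runs the same number of processes.
        return self.rank // self.local_world_size


def read_torchrun_env(env: Optional[Mapping[str, str]] = None) -> RankInfo:
    """Hypothetical helper: read the per-process rank variables set by torchrun."""
    env = os.environ if env is None else env
    info = RankInfo(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        local_world_size=int(env["LOCAL_WORLD_SIZE"]),
    )
    # Sanity checks: a local rank must index a GPU on this node only.
    if not 0 <= info.local_rank < info.local_world_size:
        raise ValueError(f"LOCAL_RANK={info.local_rank} out of range for "
                         f"{info.local_world_size} local processes")
    if not 0 <= info.rank < info.world_size:
        raise ValueError(f"RANK={info.rank} out of range for WORLD_SIZE={info.world_size}")
    return info
```

In the training script one would then call `torch.cuda.set_device(info.local_rank)` before the first collective call. Using the global rank instead would request devices 8 to 31 on machines that only have 8 GPUs.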

Victor

xyzhang626 commented 6 months ago

@llvictorll thanks anyway! I'll close this issue.