facebookresearch / mmf

A modular framework for vision & language multimodal research from Facebook AI Research (FAIR)
https://mmf.sh/

Can't reproduce UniT results with gradient accumulation #1251

Open RulinShao opened 2 years ago

RulinShao commented 2 years ago

Instructions To Reproduce the Issue:

Hi, thanks for the code! I tried to reproduce the UniT vqa2 single-task training example given in the docs: https://mmf.sh/docs/projects/unit/ The default setting uses a batch size of 64 on 64 GPUs. I want to reproduce the same result on 8 GPUs with gradient accumulation, by setting the update frequency to 8. My script:

python mmf_cli/run.py \
    config=projects/unit/configs/vqa2/single_task.yaml \
    datasets=vqa2 \
    model=unit run_type=train \
    env.save_dir=./save/unit/vqa2_single_task/batch64 \
    distributed.world_size=8 distributed.port=20000 \
    training.batch_size=8 \
    training.update_frequency=8
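
The batch-size arithmetic behind this script can be sketched as follows. This is only a sanity check, not MMF code: the function name and the `per_device` flag are hypothetical, and whether `training.batch_size` is counted across all GPUs or per GPU is an assumption to verify against the MMF configuration docs. The last check uses the `iterations` and `num_updates` values printed in the log of this issue.

```python
def effective_batch_size(batch_size, world_size, update_frequency, per_device=False):
    """Number of samples that contribute to one optimizer update.

    per_device=False assumes batch_size is the global batch summed over GPUs;
    per_device=True assumes batch_size is the batch on each GPU.
    """
    samples_per_iteration = batch_size * world_size if per_device else batch_size
    return samples_per_iteration * update_frequency


# Values taken from the script above.
global_case = effective_batch_size(batch_size=8, world_size=8, update_frequency=8)
per_device_case = effective_batch_size(8, 8, 8, per_device=True)

# From the last progress line of the log: iterations 135200, num_updates 16900.
# Their ratio is the number of forward/backward passes per optimizer update.
accumulation_steps = 135200 // 16900

print(global_case, per_device_case, accumulation_steps)
```

Under the global interpretation the script matches the intended 64 samples per update; under the per-device interpretation it would use 512, which would be a different experiment from the documented one.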

However, after setting training.update_frequency to 8 or 4, NaN appeared in the loss and training terminated. The log is below:

......
2022-06-18T10:22:10 | INFO | mmf.trainers.callbacks.logistics : progress: 15700/150000, train/vqa2/loss_0: 3.8779, train/vqa2/loss_0/avg: 33.4238, train/total_loss: 3.8779, train/total_loss/avg: 33.4238, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 15700, iterations: 125600, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 04m 201ms, time_since_start: 10h 30m 57s 070ms, eta: 119h 06m 28s 144ms
2022-06-18T10:26:11 | INFO | mmf.trainers.callbacks.logistics : progress: 15800/150000, train/vqa2/loss_0: 3.8779, train/vqa2/loss_0/avg: 33.2336, train/total_loss: 3.8779, train/total_loss/avg: 33.2336, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 15800, iterations: 126400, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 04m 887ms, time_since_start: 10h 34m 57s 957ms, eta: 119h 21m 31s 117ms
2022-06-18T10:30:15 | INFO | mmf.trainers.callbacks.logistics : progress: 15900/150000, train/vqa2/loss_0: 3.8779, train/vqa2/loss_0/avg: 33.0513, train/total_loss: 3.8779, train/total_loss/avg: 33.0513, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 15900, iterations: 127200, max_updates: 150000, lr: 0.00005, ups: 0.41, time: 04m 04s 049ms, time_since_start: 10h 39m 02s 006ms, eta: 120h 50m 07s 210ms
2022-06-18T10:34:15 | INFO | mmf.trainers.callbacks.logistics : progress: 16000/150000, train/vqa2/loss_0: 3.8779, train/vqa2/loss_0/avg: 32.8726, train/total_loss: 3.8779, train/total_loss/avg: 32.8726, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16000, iterations: 128000, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 59s 910ms, time_since_start: 10h 43m 01s 917ms, eta: 118h 41m 51s 159ms
2022-06-18T10:38:14 | INFO | mmf.trainers.callbacks.logistics : progress: 16100/150000, train/vqa2/loss_0: 3.8203, train/vqa2/loss_0/avg: 32.6894, train/total_loss: 3.8203, train/total_loss/avg: 32.6894, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16100, iterations: 128800, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 59s 091ms, time_since_start: 10h 47m 01s 008ms, eta: 118h 12m 15s 089ms
2022-06-18T10:42:13 | INFO | mmf.trainers.callbacks.logistics : progress: 16200/150000, train/vqa2/loss_0: 3.8203, train/vqa2/loss_0/avg: 32.5049, train/total_loss: 3.8203, train/total_loss/avg: 32.5049, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16200, iterations: 129600, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 58s 903ms, time_since_start: 10h 50m 59s 912ms, eta: 118h 01m 22s 668ms
2022-06-18T10:46:14 | INFO | mmf.trainers.callbacks.logistics : progress: 16300/150000, train/vqa2/loss_0: 3.8203, train/vqa2/loss_0/avg: 32.3293, train/total_loss: 3.8203, train/total_loss/avg: 32.3293, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16300, iterations: 130400, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 04m 999ms, time_since_start: 10h 55m 911ms, eta: 118h 58m 09s 610ms
2022-06-18T10:50:16 | INFO | mmf.trainers.callbacks.logistics : progress: 16400/150000, train/vqa2/loss_0: 3.8203, train/vqa2/loss_0/avg: 32.1515, train/total_loss: 3.8203, train/total_loss/avg: 32.1515, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16400, iterations: 131200, max_updates: 150000, lr: 0.00005, ups: 0.41, time: 04m 02s 179ms, time_since_start: 10h 59m 03s 091ms, eta: 119h 27m 44s 608ms
2022-06-18T10:54:16 | INFO | mmf.trainers.callbacks.logistics : progress: 16500/150000, train/vqa2/loss_0: 3.5749, train/vqa2/loss_0/avg: 31.9783, train/total_loss: 3.5749, train/total_loss/avg: 31.9783, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16500, iterations: 132000, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 59s 992ms, time_since_start: 11h 03m 03s 083ms, eta: 118h 17m 41s 825ms
2022-06-18T10:58:16 | INFO | mmf.trainers.callbacks.logistics : progress: 16600/150000, train/vqa2/loss_0: 3.5749, train/vqa2/loss_0/avg: 31.8103, train/total_loss: 3.5749, train/total_loss/avg: 31.8103, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16600, iterations: 132800, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 59s 788ms, time_since_start: 11h 07m 02s 871ms, eta: 118h 06m 21s 111ms
2022-06-18T11:02:14 | INFO | mmf.trainers.callbacks.logistics : progress: 16700/150000, train/vqa2/loss_0: 3.5749, train/vqa2/loss_0/avg: 31.6493, train/total_loss: 3.5749, train/total_loss/avg: 31.6493, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16700, iterations: 133600, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 58s 312ms, time_since_start: 11h 11m 01s 184ms, eta: 117h 17m 28s 390ms
2022-06-18T11:06:12 | INFO | mmf.trainers.callbacks.logistics : progress: 16800/150000, train/vqa2/loss_0: 3.8080, train/vqa2/loss_0/avg: 31.4836, train/total_loss: 3.8080, train/total_loss/avg: 31.4836, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16800, iterations: 134400, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 57s 870ms, time_since_start: 11h 14m 59s 054ms, eta: 116h 59m 08s 055ms
2022-06-18T11:10:10 | INFO | mmf.trainers.callbacks.logistics : progress: 16900/150000, train/vqa2/loss_0: 3.8080, train/vqa2/loss_0/avg: 31.3207, train/total_loss: 3.8080, train/total_loss/avg: 31.3207, max mem: 8548.0, experiment: run, epoch: 1, num_updates: 16900, iterations: 135200, max_updates: 150000, lr: 0.00005, ups: 0.42, time: 03m 57s 935ms, time_since_start: 11h 18m 56s 990ms, eta: 116h 55m 47s 129ms
2022-06-18T11:12:15 | WARNING | py.warnings : /home/ubuntu/anaconda3/envs/mmf/lib/python3.7/site-packages/torch/nn/utils/clip_grad.py:55: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
  FutureWarning, stacklevel=2)

2022-06-18T11:12:15 | WARNING | py.warnings : /home/ubuntu/anaconda3/envs/mmf/lib/python3.7/site-packages/torch/nn/utils/clip_grad.py:55: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
  FutureWarning, stacklevel=2)

2022-06-18T11:12:15 | INFO | mmf.trainers.core.training_loop : NaN occurred in the following loss(es): train/vqa2/loss_0; exiting the training
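
To find the last healthy report before the NaN in a long log like the one above, a small parser can help. The sketch below is not part of MMF; the regular expressions are written against the line format shown in this log only, the function names are hypothetical, and the sample lines are abridged copies of the lines above.

```python
import math
import re

# Patterns written against the progress-line format shown in this issue.
UPDATES_RE = re.compile(r"num_updates: (\d+)")
LOSS_RE = re.compile(r"train/total_loss: ([^,\s]+)")


def parse_progress(line):
    """Return (num_updates, total_loss) for a progress line, else None."""
    updates = UPDATES_RE.search(line)
    loss = LOSS_RE.search(line)
    if updates is None or loss is None:
        return None
    return int(updates.group(1)), float(loss.group(1))


def last_finite_report(lines):
    """Return the latest (num_updates, total_loss) whose loss is finite."""
    records = [r for r in map(parse_progress, lines) if r is not None]
    finite = [r for r in records if math.isfinite(r[1])]
    return max(finite, key=lambda r: r[0]) if finite else None


# Abridged copies of lines from the log above.
SAMPLE = [
    "2022-06-18T11:02:14 | INFO | mmf.trainers.callbacks.logistics : progress: 16700/150000, train/total_loss: 3.5749, train/total_loss/avg: 31.6493, num_updates: 16700, iterations: 133600",
    "2022-06-18T11:06:12 | INFO | mmf.trainers.callbacks.logistics : progress: 16800/150000, train/total_loss: 3.8080, train/total_loss/avg: 31.4836, num_updates: 16800, iterations: 134400",
    "2022-06-18T11:10:10 | INFO | mmf.trainers.callbacks.logistics : progress: 16900/150000, train/total_loss: 3.8080, train/total_loss/avg: 31.3207, num_updates: 16900, iterations: 135200",
    "2022-06-18T11:12:15 | INFO | mmf.trainers.core.training_loop : NaN occurred in the following loss(es): train/vqa2/loss_0; exiting the training",
]

print(last_finite_report(SAMPLE))
```

On these lines the last finite report is at update 16900, about two minutes before the non-finite gradient norm warning.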

Expected behavior:

Training should finish normally and reach the expected results reported in Table 1, line 1 of the UniT paper.

Environment:

I use a p3.16 AWS instance with 8 V100 GPUs, each with 16 GB of memory. My environment was built strictly following the MMF installation instructions:

# Create new env
conda create -n mmf python=3.7
conda activate mmf
cd mmf
pip install -r requirements.txt

RulinShao commented 2 years ago

I can load the saved checkpoint and resume training, but the NaN does not appear at the same iteration. Instead, it appears 16900 iterations after each start. For example, when I resumed training from the checkpoint saved at iteration 10000, NaN was reported at iteration 26900. Any insight on this?
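
The two runs described above can be compared against two simple hypotheses: the failure is tied to the absolute iteration count, or it is tied to the number of iterations since the process started. The sketch below only encodes that arithmetic with the numbers from this comment; the function and hypothesis names are hypothetical, and the result narrows the search without identifying a cause.

```python
def predicted_failure(resume_step, offset, counter):
    """Iteration at which NaN would appear under a given hypothesis."""
    if counter == "absolute":
        return offset                  # depends only on the global iteration
    if counter == "since_start":
        return resume_step + offset    # depends on iterations run in this process
    raise ValueError(f"unknown counter: {counter}")


OFFSET = 16900        # first run started at 0 and failed around iteration 16900
RESUME_STEP = 10000   # checkpoint used for the second run
OBSERVED = 26900      # iteration at which the resumed run reported NaN

consistent = [
    name
    for name in ("absolute", "since_start")
    if predicted_failure(RESUME_STEP, OFFSET, name) == OBSERVED
]
print(consistent)
```

Only the "since_start" hypothesis matches the observation, which suggests looking at state that is reset when the process restarts rather than at a particular batch in the dataset.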