mk-minchul / AdaFace

MIT License
625 stars, 118 forks

Training AdaFace on WebFace42M: OOM #72

Closed. xxiMiaxx closed this issue 1 year ago

xxiMiaxx commented 1 year ago

Dear @mk-minchul, thank you for this amazing work.

I have been trying to get results by training AdaFace on WebFace42M with a ResNet100 backbone. I'm using 8× A100 (40 GB) GPUs, but I keep getting an OOM (out of memory) error even though I'm using DDP as the training strategy.

Training parameters:



Here is the training log:


\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
start training
classnum: 2059906
classnum: 2059906
classnum: 2059906
classnum: 2059906
classnum: 2059906
Global seed set to 42
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8
classnum: 2059906

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/8
classnum: 2059906

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/8

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/8

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/8

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/8

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/8

\AdaFace with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 0.01
Global seed set to 42
start training
Global seed set to 42
initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/8
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------

creating train dataset
creating train dataset
creating train dataset
creating train dataset
creating train dataset
creating train dataset
creating train dataset
creating train dataset
record file length 42474558
record file length 42474558
record file length 42474558
record file length 42474558
record file length 42474558
record file length 42474558
record file length 42474558
record file length 42474558
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
creating val dataset
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
laoding validation data memfile
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id 6pmwqb0w.
wandb: Tracking run with wandb version 0.13.5
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.

  | Name               | Type             | Params
--------------------------------------------------------
0 | model              | Backbone         | 65.2 M
1 | head               | AdaFace          | 1.1 B 
2 | cross_entropy_loss | CrossEntropyLoss | 0     
--------------------------------------------------------
1.1 B     Trainable params
0         Non-trainable params
1.1 B     Total params
2,239.646 Total estimated model params size (MB)
/home/ma-user/anaconda/lib/python3.7/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:631: UserWarning: Checkpoint directory /home/ma-user/modelarts/outputs/output_0/experiments/webface42m_v2_20_epoch_workers8_bs32_wandb_8gpu_11-12_0 exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

Validation sanity check: 0it [00:00, ?it/s]
Validation sanity check:  99%|█████████▉| 99/100 [00:10<00:00, 16.46it/s]
/home/ma-user/anaconda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('agedb_30_num_val_samples', ...)` in your `validation_epoch_end` but the value needs to be floating point. Converting it to torch.float32.
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"
/home/ma-user/anaconda/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('epoch', ...)` in your `validation_epoch_end` but the value needs to be floating point. Converting it to torch.float32.
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"

Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/1329268 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/1329268 [00:00<?, ?it/s]
time="2022-11-12T13:14:32+08:00" level=info msg="clean up child process succeed, pid=5985, wstatus=0, exit_status=0" file="cleaner.go:69" Command=bootstrap/run Component=ma-training-toolkit Platform=ModelArts-Service
time="2022-11-12T13:14:32+08:00" level=info msg="clean up child process succeed, pid=5616, wstatus=0, exit_status=0" file="cleaner.go:69" Command=bootstrap/run Component=ma-training-toolkit Platform=ModelArts-Service

Thank you,
Lamia

trnikon commented 1 year ago

I have come across this issue often. I think the problem is that it tries to load the whole training dataset into memory and it doesn't fit. I am using 4× RTX 3090 to train a ResNet50 on WebFace260M. When I split the data and trained only on the first 500k classes, it worked without problems. There should be a way to configure the dataloader to lazy-load the data from disk, but I don't have the skills to implement it.
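If host RAM (rather than GPU memory) were the bottleneck, the usual remedy is a map-style dataset that keeps only a small index in memory and reads each sample from disk on demand. The sketch below is a hypothetical, stdlib-only illustration: `LazyFileDataset`, `scan_class_folders`, and the `root/<class_name>/<file>` layout are assumptions, not part of AdaFace (the log above reports that AdaFace reads from a record file). PyTorch's `DataLoader` accepts any object implementing `__len__` and `__getitem__` as a map-style dataset.

```python
import os
from typing import Callable, List, Optional, Tuple


def read_bytes(path: str) -> bytes:
    """Default loader: return the raw file contents (decoding is left to the caller)."""
    with open(path, "rb") as f:
        return f.read()


class LazyFileDataset:
    """Hypothetical map-style dataset that keeps only (path, label) pairs in RAM.

    Each file is opened inside __getitem__, so pixel data is never held
    for the whole dataset at once.
    """

    def __init__(self, samples: List[Tuple[str, int]],
                 loader: Callable[[str], object] = read_bytes,
                 transform: Optional[Callable[[object], object]] = None):
        self.samples = samples      # lightweight index, not image data
        self.loader = loader
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        item = self.loader(path)    # the disk read happens here, one sample at a time
        if self.transform is not None:
            item = self.transform(item)
        return item, label


def scan_class_folders(root: str) -> List[Tuple[str, int]]:
    """Build (path, label) pairs from an assumed root/<class_name>/<file> layout."""
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    samples = []
    for label, name in enumerate(classes):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, fname), label))
    return samples
```

To get decoded images, one could pass a loader such as `lambda p: PIL.Image.open(p).convert("RGB")`. Note that under DDP every rank builds its own index, so even the path list is held once per process.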

bomcon123456 commented 1 year ago

Have you found a way to eliminate this? I think this is not about the dataloader but about the classification head: 2M identities give ~1.1B parameters (according to your log).
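The numbers in the model summary can be checked by hand. Assuming the head stores one 512-dimensional class center per identity with no bias (the 512 is an assumption here, but it reproduces the logged ~1.1B total), the sketch below recomputes the parameter count and a rough per-GPU memory estimate for the head alone. Under DDP every rank holds a full replica; counting fp32 weights, gradients, and one momentum buffer assumes SGD with momentum and ignores activations and allocator overhead.

```python
def head_param_count(num_classes: int, embedding_dim: int = 512) -> int:
    # One embedding_dim-sized class center per identity; no bias term (assumed).
    return num_classes * embedding_dim


def gib(num_bytes: float) -> float:
    return num_bytes / 2**30


NUM_CLASSES = 2_059_906    # "classnum" printed in the log
BACKBONE_PARAMS = 65.2e6   # "model | Backbone | 65.2 M" from the model summary

head = head_param_count(NUM_CLASSES)
print(f"head params: {head:,}")                                    # 1,054,671,872
print(f"backbone + head: {(BACKBONE_PARAMS + head) / 1e9:.2f} B")  # 1.12 B

# Rough per-GPU cost of the head under DDP (every rank keeps a full copy).
# Assumes fp32 weights, fp32 gradients and one fp32 momentum buffer.
bytes_per_tensor = head * 4
print(f"head weights: {gib(bytes_per_tensor):.1f} GiB")                      # 3.9 GiB
print(f"weights + grads + momentum: {gib(3 * bytes_per_tensor):.1f} GiB")    # 11.8 GiB
```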

trnikon commented 1 year ago

Eventually I managed to prevent the OOM errors by creating a 60 GB swap file and training on 8× A100 with batch size 128.

xxiMiaxx commented 1 year ago

> Have you found a way to eliminate this? I think this is not about dataloader but about the classification head (2M identity) -> ~1.1B params (according to your log)

Correct: the large number of identities creates a massive final fully connected layer. I managed to train AdaFace on WebFace42M using this implementation of PartialFC by insightface.
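The core idea of PartialFC is to compute logits against only a subset of class centers at each step: every class present in the batch plus a random sample of the remaining ones. The NumPy sketch below shows just that sampling and label remapping in a single process. It is a simplified illustration, not insightface's implementation (which also shards the class centers across GPUs); `sample_class_centers`, `partial_cosine_logits`, and the `(num_classes, dim)` center layout are hypothetical.

```python
import numpy as np


def sample_class_centers(labels: np.ndarray, num_classes: int, sample_rate: float,
                         rng: np.random.Generator):
    """Keep every class present in the batch plus random negative classes.

    Returns the sorted indices of the kept centers and the batch labels
    remapped into that smaller index space.
    """
    positives = np.unique(labels)
    num_sample = max(int(sample_rate * num_classes), positives.size)
    # Candidate negatives: all classes that do not appear in this batch.
    mask = np.ones(num_classes, dtype=bool)
    mask[positives] = False
    negatives = np.flatnonzero(mask)
    chosen_neg = rng.choice(negatives, size=num_sample - positives.size, replace=False)
    index = np.sort(np.concatenate([positives, chosen_neg]))
    # Position of each original label inside the sorted `index`.
    remapped = np.searchsorted(index, labels)
    return index, remapped


def partial_cosine_logits(embeddings: np.ndarray, centers: np.ndarray,
                          index: np.ndarray) -> np.ndarray:
    """Cosine similarity between L2-normalised embeddings and the sampled centers only."""
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sub = centers[index]
    sub = sub / np.linalg.norm(sub, axis=1, keepdims=True)
    return emb @ sub.T  # shape (batch, len(index)) instead of (batch, num_classes)
```

An AdaFace-style margin would then be applied to `logits[np.arange(len(labels)), remapped]` before the softmax. With `sample_rate=0.1` the logit matrix has roughly a tenth of the columns of the full head.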

TouchSkyWf commented 1 year ago

> > Have you found a way to eliminate this? I think this is not about dataloader but about the classification head (2M identity) -> ~1.1B params (according to your log)

> correct, the large number of identities creates a massive last fully connected layer, I have managed to train AdaFace on WebFace42M using this implementation of PartialFC by insightface.

Could you open-source your method for training AdaFace with PartialFC? Thank you so much!

shritor commented 7 months ago

> > Have you found a way to eliminate this? I think this is not about dataloader but about the classification head (2M identity) -> ~1.1B params (according to your log)

> correct, the large number of identities creates a massive last fully connected layer, I have managed to train AdaFace on WebFace42M using this implementation of PartialFC by insightface.

Can you send me the WebFace42M data or a pretrained model? My email is kakashijin15@gmail.com. Thank you so much!