kohya-ss / sd-scripts


--resume decreases training speed by 4-10x #1708

Open Tophness opened 1 month ago

Tophness commented 1 month ago

Reliably goes from 10s/it when training to 40-100s/it with this configuration, and if I remember correctly it was even adding over 200s/it on other configurations. I'm spilling over into shared VRAM a lot, so I'm guessing it's holding on to the previous state in memory after loading it. Literally the only parameter changed is `--resume`.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py \
  --pretrained_model_name_or_path "consolidated_s6700.safetensors" \
  --clip_l "clip_l.safetensors" --t5xxl "t5xxl_fp16.safetensors" --ae "ae.safetensors" \
  --cache_latents_to_disk --save_model_as safetensors --sdpa \
  --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 \
  --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
  --network_module networks.lora_flux --network_dim 16 \
  --optimizer_type bitsandbytes.optim.AdEMAMix8bit --learning_rate 1e-4 --fp8_base \
  --max_train_epochs 400 --save_every_n_epochs 1 --model.toml \
  --output_dir outputs --output_name model \
  --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw \
  --guidance_scale 1.0 --resume "outputs/model-000001-state" --save_state \
  --network_args "loraplus_unet_lr_ratio=4" "train_t5xxl=True" "split_qkv=True"
```
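A minimal diagnostic sketch (not from the original report; the `Accelerator` setup and `resume_path` are placeholders) for checking whether the resume state stays resident on the GPU after `accelerator.load_state`, which would explain the spill into shared VRAM:

```python
# Rough sketch, assuming an Accelerate-based loop like flux_train_network.py's:
# measure allocated GPU memory around load_state to see whether resuming leaves
# extra tensors resident (which would force spilling into shared VRAM).
import gc
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
# In the real script the model, optimizer and dataloader are prepared first, e.g.:
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

resume_path = "outputs/model-000001-state"  # placeholder: the directory passed to --resume

before = torch.cuda.memory_allocated()
accelerator.load_state(resume_path)
after_load = torch.cuda.memory_allocated()

gc.collect()
torch.cuda.empty_cache()
after_cleanup = torch.cuda.memory_allocated()

print(f"allocated before load:   {before / 2**20:.1f} MiB")
print(f"allocated after load:    {after_load / 2**20:.1f} MiB")
print(f"allocated after cleanup: {after_cleanup / 2**20:.1f} MiB")
```

If the "after cleanup" number stays well above the "before load" number, something is still holding references to the loaded state.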

Tophness commented 1 month ago

If this is fixed I'll close #1700, since I'm able to train for a while before OOM, and whatever the recent memory changes were, they've increased my regular speed from ~70s/it to ~10s/it. If I can auto-resume after a crash, it works out much faster anyway.

I realise I'm probably the only one trying to push this much on a 16 GB card, but I need text encoder training and better optimisers to get decent results on this concept, and I'm happy to wait longer.

feffy380 commented 1 month ago

I think I'm encountering the same issue with a different optimizer (schedulefree): training works fine initially, but resuming an existing run uses just enough extra memory to cause an OOM during sample generation. It sounds like a memory leak involving the resume state.
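One way to narrow this down (a sketch of a diagnostic, not a fix in the script; the hook point is assumed) is to release cached allocator blocks right before sample generation on a resumed run. If the OOM then disappears, leftover memory from the resume is the likely culprit rather than the samples themselves:

```python
# Sketch: drop stale references and release cached CUDA blocks just before
# sample generation on a resumed run.
import gc
import torch

def free_cuda_memory() -> None:
    gc.collect()               # drop Python references to any stale resume-state tensors
    torch.cuda.empty_cache()   # return cached allocator blocks to the driver

# assumed usage: call this right before the sampling step in the training loop
free_cuda_memory()
```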

kohya-ss commented 1 month ago

Maybe Accelerate's resuming or the optimizer's `load_state` doesn't fully restore the optimizer state. If that's the case, there may be little we can do about it... But I'll take a look at it when I have time.
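A sketch of how this could be checked (assumed usage, not part of the script): right after `accelerator.load_state`, dump the optimizer's per-parameter state and look at the device and dtype of the restored tensors. An empty state dict, or state tensors restored to the wrong device, would both fit the symptoms.

```python
# Sketch: summarize the optimizer state right after accelerator.load_state(args.resume).
# bitsandbytes 8-bit optimizers use different state keys than plain Adam, so this just
# prints whatever tensors are present instead of assuming specific key names.
import torch

def summarize_optimizer_state(optimizer: torch.optim.Optimizer, max_params: int = 3) -> None:
    states = list(optimizer.state.items())
    if not states:
        print("optimizer.state is empty -- resume did not restore any per-parameter state")
        return
    print(f"{len(states)} parameters have optimizer state; showing the first {max_params}:")
    for _param, state in states[:max_params]:
        for key, value in state.items():
            if torch.is_tensor(value):
                print(f"  {key}: device={value.device}, dtype={value.dtype}, shape={tuple(value.shape)}")

# assumed usage, right after the state is loaded:
# summarize_optimizer_state(optimizer)
```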

Tophness commented 1 month ago

> Maybe Accelerate's resuming or the optimizer's `load_state` doesn't fully restore the optimizer state.

"Not restoring the optimizer state" does sound a lot like what's happening. I trained 8 epochs at 10s/it and resumed up to 17 epochs at ~40s/it, resumed from 17 to 24 and it shot down to ~21s/it, then resumed from 24 and it went back down to ~10s/it.

Edit: Since then, I added an extra 32 GB of RAM to my system (which it was using all of) and resumed. It went down to 5s/it resuming from 24 up to epoch 45. Resuming from epoch 45 went back up to >40s/it, so I tried to get it back down by stopping at epoch 45 and resuming, then at 46 and resuming, but it's still at >40s/it.

But the weirdest part is the epoch counting. When I resumed at 8, it said I was at epoch 0, then after one iteration it jumped to 9 and continued incrementing as normal. When I resumed at 17, it said it was starting at epoch 7, and at 24 it said it was starting at 17, so it went from 17 back up to 24, literally writing states 17-24, and it overwrote the old state 24 I was resuming from. I paused at the new state 24, started again from it, and it's now saying "epoch is incremented. current_epoch: 0, epoch: 9". When it loaded the training state it did say "train_state.json: {'current_epoch': 24, 'current_step': 2608}", but then it says 9 and the progress bar starts at 0/127792.

There are 130400 steps in total and 326 steps per epoch, so 130400 - 127792 = 2608 steps, and 2608 / 326 = 8 epochs. Is this expected behaviour? Is the step count just an arbitrary number after resuming, or is it actually starting again at an earlier epoch, possibly retraining over the same weights? That would explain some of the outputs I'm getting, and possibly the memory issues.

Edit: Just found #1559, so I guess the random-looking step/epoch count is expected behaviour.
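For reference, a small sketch of the step/epoch arithmetic above (the numbers come straight from the quoted log; the interpretation that the progress bar shows *remaining* steps is an assumption consistent with #1559):

```python
# Sketch of the step/epoch arithmetic from the log above.
total_steps = 130_400      # 400 epochs * 326 steps per epoch
steps_per_epoch = 326
remaining_steps = 127_792  # what the progress bar showed after resuming

completed_steps = total_steps - remaining_steps        # 2608, matches train_state.json
completed_epochs = completed_steps // steps_per_epoch  # 8

print(f"completed steps:  {completed_steps}")
print(f"completed epochs: {completed_epochs}")
```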