ChenDarYen / Key-Locked-Rank-One-Editing-for-Text-to-Image-Personalization

A PyTorch implementation of the paper Key-Locked Rank One Editing for Text-to-Image Personalization
MIT License

Very long training time #13

Open PaulToast opened 11 months ago

PaulToast commented 11 months ago

I'm not super experienced, so I don't know if something is actually going wrong.

I managed to get the training going and am seeing results that move somewhat in the right direction. But training is taking quite a long time (especially compared to the speed promised in the paper), and the quality of the trained concepts is very lacking.

After about 1h of training it has not finished the first epoch. It says something like "Epoch 0: (2276/6250)". Do those numbers refer to the training timesteps? Because in the config file the max number of steps is set to 400, and I apparently do reach that number, because it outputs the step=400.ckpt model when I let it run all the way through.

Is something going wrong here?

I'm using an NVIDIA GeForce RTX 3060 with 12 GB of VRAM; CUDA is working fine.

PaulToast commented 11 months ago

Here's the output I'm getting while training.


Running on GPUs 0,
Perfusion: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
DiffusionWrapper has 867.83 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Restored from ./ckpt/v2-1_512-ema-pruned.ckpt with 38 missing and 2 unexpected keys
Missing Keys:
 ['logvar', 'C_inv', 'target_input', 'model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v.target_output', 'model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k.target_output', 'model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v.target_output', 'embedding_manager.string_to_param_dict.*', 'embedding_manager.initial_embeddings.*', 'embedding_manager.get_embedding_for_tkn.weight']

Unexpected Keys:
 ['model_ema.decay', 'model_ema.num_updates']
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
accumulate_grad_batches = 4
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                   | Params
-------------------------------------------------------------
0 | model             | DiffusionWrapper       | 867 M 
1 | first_stage_model | AutoencoderKL          | 83.7 M
2 | cond_stage_model  | FrozenOpenCLIPEmbedder | 354 M 
3 | embedding_manager | EmbeddingManager       | 50.6 M
-------------------------------------------------------------
961 K     Trainable params
1.3 B     Non-trainable params
1.3 B     Total params
2,611.042 Total estimated model params size (MB)
Save project config
Save lightning config
Epoch 0:   0%|                                                                                                        | 0/5000 [00:00<?, ?it/s]Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 50 timesteps
DDIM Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00,  2.21it/s]
Epoch 0:   0%| | 20/5000 [00:59<4:08:41,  3.00s/it, loss=0.0518, v_num=, train/loss_simple_step=0.0522, train/loss_vlb_step=0.000219, train/los
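
Also, not sure if it matters, but the "DDIM Sampler" run in that log seems to be an image-logging callback generating preview images during training, which would add wall-clock time on top of the optimizer steps themselves. Something roughly like this, just to illustrate the overhead (the class and parameter names are made up, not this repo's actual API):

```python
import pytorch_lightning as pl

class PreviewSampler(pl.Callback):
    """Hypothetical image-logging callback: every `every_n_steps` training
    steps it runs a full 50-step DDIM sample, which is why 'DDIM Sampler'
    progress bars appear inside the training epoch."""

    def __init__(self, every_n_steps: int = 500):
        self.every_n_steps = every_n_steps  # raise this to spend less time sampling

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step % self.every_n_steps == 0:
            pass  # run DDIM sampling here and log/save the preview images
```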
SlZeroth commented 11 months ago

me too

edufschmidt commented 9 months ago

@SlZeroth @PaulToast I'm experiencing the same thing, even when running on a relatively beefy GPU. Were you able to speed up training somehow?

PaulToast commented 8 months ago

> @SlZeroth @PaulToast I'm experiencing the same thing, even when running on a relatively beefy GPU. Were you able to speed up training somehow?

Well, mainly I was given a more powerful graphics card, haha. But I don't think the training was going that slow after all; the numbers are a little misleading. It's not supposed to go through multiple epochs, it just runs until it reaches the max_steps value set in the config (see the sketch below). You will probably never have to train further than step=600 or so.
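
To make the mechanics concrete, here is a minimal sketch of the relevant PyTorch Lightning settings (this is not the repo's actual training script; the values mirror the log above, and `model`/`dataloader` are placeholders):

```python
import pytorch_lightning as pl

# Lightning counts `max_steps` in *optimizer* steps and halts there, even if
# that lands in the middle of an epoch, so the epoch progress bar (e.g.
# "Epoch 0: 2276/6250") never needs to reach its total before training ends.
trainer = pl.Trainer(
    max_steps=400,              # stop after 400 optimizer steps...
    max_epochs=-1,              # ...no matter where the epoch boundary falls
    accumulate_grad_batches=4,  # as in the log: 4 batches per optimizer step
    precision=16,               # matches "16bit native Automatic Mixed Precision"
)
# trainer.fit(model, dataloader)  # would stop at global step 400, mid-"Epoch 0"
```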

edufschmidt commented 8 months ago

Thanks @PaulToast, yeah, that makes sense. I eventually figured it out after I found a mistake in how I was passing the args to the training script 🤦🏻‍♂️