wutaiqiang / MoSLoRA


Low-quality image output from subject_driven_generation #3

Open vitrun opened 2 weeks ago

vitrun commented 2 weeks ago

Following README.md, I tested subject_driven_generation with:

sh train_sdxl_lora_cat.sh
python3 infer.py

and got low-quality images from the mixer model, while the vanilla LoRA remained quite good. For example:

wutaiqiang commented 2 weeks ago

Could you share your training log so that I can check what happened?

vitrun commented 2 days ago

Hope this helps:

Namespace(pretrained_model_name_or_path='stabilityai/stable-diffusion-xl-base-1.0', pretrained_vae_model_name_or_path='madebyollin/sdxl-vae-fp16-fix', revision=None, variant=None, dataset_name=None, dataset_config_name=None, instance_data_dir='cat', cache_dir=None, image_column='image', caption_column=None, repeats=1, class_data_dir=None, instance_prompt='a photo of ikat cat', class_prompt=None, validation_prompt='A photo of ikat cat wearing a red hat', num_validation_images=4, validation_epochs=40, do_edm_style_training=False, with_prior_preservation=False, prior_loss_weight=1.0, num_class_images=100, output_dir='lora-trained-xl-mixer-cat', output_kohya_format=False, seed=0, resolution=1024, center_crop=False, random_flip=False, train_text_encoder=False, train_batch_size=1, sample_batch_size=4, num_train_epochs=1, max_train_steps=500, checkpointing_steps=500, checkpoints_total_limit=None, resume_from_checkpoint=None, gradient_accumulation_steps=4, gradient_checkpointing=False, learning_rate=0.0001, text_encoder_lr=5e-06, scale_lr=False, lr_scheduler='constant', snr_gamma=None, lr_warmup_steps=0, lr_num_cycles=1, lr_power=1.0, dataloader_num_workers=0, optimizer='AdamW', use_8bit_adam=False, adam_beta1=0.9, adam_beta2=0.999, prodigy_beta3=None, prodigy_decouple=True, adam_weight_decay=0.0001, adam_weight_decay_text_encoder=0.001, adam_epsilon=1e-08, prodigy_use_bias_correction=True, prodigy_safeguard_warmup=True, max_grad_norm=1.0, push_to_hub=False, hub_token=None, hub_model_id=None, logging_dir='logs', allow_tf32=False, report_to='wandb', mixed_precision='fp16', prior_generation_precision=None, local_rank=-1, enable_xformers_memory_efficient_attention=False, rank=4, use_dora=False, lora_use_mixer=True)
/root/miniconda3/envs/moslora_sd/lib/python3.10/site-packages/accelerate/accelerator.py:494: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
Detected kernel version 5.4.143, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
10/08/2024 12:11:38 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16

You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
{'variance_type', 'rescale_betas_zero_snr', 'dynamic_thresholding_ratio', 'thresholding', 'clip_sample_range'} was not found in config. Values will be initialized to default values.
{'use_quant_conv', 'latents_std', 'latents_mean', 'shift_factor', 'use_post_quant_conv', 'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
{'dropout', 'reverse_transformer_layers_per_block', 'attention_type'} was not found in config. Values will be initialized to default values.

...

10/08/2024 12:12:30 - INFO - __main__ - ***** Running training *****
10/08/2024 12:12:30 - INFO - __main__ -   Num examples = 5
10/08/2024 12:12:30 - INFO - __main__ -   Num batches each epoch = 5
10/08/2024 12:12:30 - INFO - __main__ -   Num Epochs = 250
10/08/2024 12:12:30 - INFO - __main__ -   Instantaneous batch size per device = 1
10/08/2024 12:12:30 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
10/08/2024 12:12:30 - INFO - __main__ -   Gradient Accumulation steps = 4
10/08/2024 12:12:30 - INFO - __main__ -   Total optimization steps = 500
Steps:   0%|▏                                                     | 2/500 [00:06<23:02,  2.78s/it, loss=0.134, lr=0.0001]
{'feature_extractor', 'image_encoder'} was not found in config. Values will be initialized to default values.
{'final_sigmas_type', 'rescale_betas_zero_snr', 'sigma_min', 'sigma_max', 'timestep_type', 'use_exponential_sigmas'} was not found in config. Values will be initialized to default values.
Loaded scheduler as EulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████| 7/7 [00:00<00:00, 40.59it/s]
10/08/2024 12:12:40 - INFO - __main__ - Running validation...
 Generating 4 images with prompt: A photo of ikat cat wearing a red hat.
{'algorithm_type', 'final_sigmas_type', 'variance_type', 'solver_type', 'rescale_betas_zero_snr', 'dynamic_thresholding_ratio', 'lower_order_final', 'thresholding', 'euler_at_final', 'solver_order', 'use_lu_lambdas', 'lambda_min_clipped'} was not found in config. Values will be initialized to default values.
Steps:   0%|▏                                                    | 2/500 [01:29<23:02,  2.78s/it, loss=0.0181, lr=0.0001]
wandb: WARNING Tried to log to step 2 that is less than the current step 3. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
wandb: WARNING Tried to log to step 2 that is less than the current step 3. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
Steps:   1%|▌                                                   | 6/500 [01:39<1:36:44, 11.75s/it, loss=0.244, lr=0.0001]
wandb: WARNING Tried to log to step 2 that is less than the current step 3. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
Steps:  16%|████████▋                                            | 82/500 [04:40<14:58,  2.15s/it, loss=0.155, lr=0.0001]
{'feature_extractor', 'image_encoder'} was not found in config. Values will be initialized to default values.
{'final_sigmas_type', 'rescale_betas_zero_snr', 'sigma_min', 'sigma_max', 'timestep_type', 'use_exponential_sigmas'} was not found in config. Values will be initialized to default values.
Loaded scheduler as EulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████| 7/7 [00:00<00:00, 43.05it/s]
10/08/2024 12:17:14 - INFO - __main__ - Running validation...
 Generating 4 images with prompt: A photo of ikat cat wearing a red hat.
{'algorithm_type', 'final_sigmas_type', 'variance_type', 'solver_type', 'rescale_betas_zero_snr', 'dynamic_thresholding_ratio', 'lower_order_final', 'thresholding', 'euler_at_final', 'solver_order', 'use_lu_lambdas', 'lambda_min_clipped'} was not found in config. Values will be initialized to default values.
Steps:  17%|████████▍                                        | 86/500 [06:09<1:10:30, 10.22s/it, loss=0.00562, lr=0.0001]
wandb: WARNING Tried to log to step 82 that is less than the current step 83. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
wandb: WARNING Tried to log to step 82 that is less than the current step 83. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
wandb: WARNING Tried to log to step 82 that is less than the current step 83. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
Steps:  32%|████████████████▊                                   | 162/500 [09:10<12:08,  2.16s/it, loss=0.106, lr=0.0001]
{'feature_extractor', 'image_encoder'} was not found in config. Values will be initialized to default values.
{'final_sigmas_type', 'rescale_betas_zero_snr', 'sigma_min', 'sigma_max', 'timestep_type', 'use_exponential_sigmas'} was not found in config. Values will be initialized to default values.
Loaded scheduler as EulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████| 7/7 [00:00<00:00, 46.12it/s]
10/08/2024 12:21:43 - INFO - __main__ - Running validation...
 Generating 4 images with prompt: A photo of ikat cat wearing a red hat.
{'algorithm_type', 'final_sigmas_type', 'variance_type', 'solver_type', 'rescale_betas_zero_snr', 'dynamic_thresholding_ratio', 'lower_order_final', 'thresholding', 'euler_at_final', 'solver_order', 'use_lu_lambdas', 'lambda_min_clipped'} was not found in config. Values will be initialized to default values.

...

Model weights saved in lora-trained-xl-mixer-cat/pytorch_lora_weights.safetensors
{'use_quant_conv', 'latents_std', 'latents_mean', 'shift_factor', 'use_post_quant_conv', 'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
{'feature_extractor', 'image_encoder'} was not found in config. Values will be initialized to default values.
Loaded text_encoder_2 as CLIPTextModelWithProjection from `text_encoder_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
{'dropout', 'reverse_transformer_layers_per_block', 'attention_type'} was not found in config. Values will be initialized to default values.
Loaded unet as UNet2DConditionModel from `unet` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
{'final_sigmas_type', 'rescale_betas_zero_snr', 'sigma_min', 'sigma_max', 'timestep_type', 'use_exponential_sigmas'} was not found in config. Values will be initialized to default values.
Loaded scheduler as EulerDiscreteScheduler from `scheduler` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of stabilityai/stable-diffusion-xl-base-1.0.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████| 7/7 [00:07<00:00,  1.03s/it]
Loading unet.

Loading adapter weights from state_dict led to unexpected keys not found in the model:  ['down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.lora_AB.default_0.weight', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.lora_AB.default_0.weight', 'down_blocks.1.attentions

...

and the loss graph: image
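The step and epoch counts in the log follow from the Namespace(...) arguments. The sketch below recomputes them. It assumes MoSLoRA's training script keeps the step accounting of diffusers' train_dreambooth_lora_sdxl.py, which its argument list matches. `training_schedule` is a hypothetical helper, not part of either repository.

```python
import math

def training_schedule(num_examples, train_batch_size, grad_accum_steps,
                      max_train_steps, validation_epochs, num_processes=1):
    """Recompute the step/epoch numbers printed in the training log.

    Assumption: one optimizer step every `grad_accum_steps` batches, plus one
    at the end of each epoch for leftover batches (hence the ceil division).
    """
    batches_per_epoch = math.ceil(num_examples / train_batch_size)
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    num_epochs = math.ceil(max_train_steps / steps_per_epoch)
    total_batch_size = train_batch_size * num_processes * grad_accum_steps
    # Validation runs at the end of each epoch where epoch % validation_epochs == 0;
    # record the global step reached at that point.
    validation_steps = [
        min((epoch + 1) * steps_per_epoch, max_train_steps)
        for epoch in range(num_epochs)
        if epoch % validation_epochs == 0
    ]
    return {
        "batches_per_epoch": batches_per_epoch,
        "steps_per_epoch": steps_per_epoch,
        "num_epochs": num_epochs,
        "total_batch_size": total_batch_size,
        "validation_steps": validation_steps,
    }

# Values taken from the Namespace(...) line and the "Num examples = 5" log entry.
schedule = training_schedule(num_examples=5, train_batch_size=1, grad_accum_steps=4,
                             max_train_steps=500, validation_epochs=40)
print(schedule)
```

Under these assumptions the first validation steps come out as 2, 82 and 162. These match where the progress bar pauses for "Running validation..." in the log, so the run appears to follow the configured schedule.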

I believe the issue stems from the warning "Loading adapter weights from state_dict led to unexpected keys not found in the model: ...". The same warning appears when running infer.py.
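One way to test this hypothesis is to sort the unexpected keys by LoRA component. If every unexpected key is a lora_AB tensor, then A and B were loaded but the learned mixer was not. Inference would then apply the LoRA update without the mixer it was trained with, which would be consistent with the degraded images. This is a minimal sketch: it assumes PEFT-style key names as seen in the warning, and the helper names are hypothetical.

```python
from collections import Counter

def lora_component(key):
    """Return which LoRA tensor a state-dict key belongs to.

    Assumes names like '...to_q.lora_AB.default_0.weight'; 'lora_AB' is the
    mixer name seen in the warning. Matching whole dot-separated segments keeps
    'lora_A' from matching 'lora_AB'.
    """
    for part in key.split("."):
        if part in ("lora_A", "lora_B", "lora_AB"):
            return part
    return "other"

def summarize_unexpected(unexpected_keys):
    """Count unexpected keys per LoRA component."""
    return Counter(lora_component(k) for k in unexpected_keys)

# First keys copied from the (truncated) warning above.
unexpected = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_AB.default_0.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.lora_AB.default_0.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.lora_AB.default_0.weight",
]
counts = summarize_unexpected(unexpected)
only_mixers_dropped = set(counts) == {"lora_AB"}
print(counts, only_mixers_dropped)
```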

vitrun commented 2 days ago

The following patch is a quick hack. It updates diffusers to support MoSLoRA, which fixes the issue mentioned above:

--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -180,7 +180,9 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
+    use_mixer = any("lora_AB" in k for k in peft_state_dict)

     lora_config_kwargs = {
         "r": r,
         "lora_alpha": lora_alpha,
@@ -188,6 +190,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "alpha_pattern": alpha_pattern,
         "target_modules": target_modules,
         "use_dora": use_dora,
+        "lora_use_mixer": use_mixer,
     }
     return lora_config_kwargs
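The patch infers the mixer flag purely from key names, the same way `use_dora` is inferred. The standalone sketch below mirrors only the lines shown in the diff and omits the rank/alpha handling of the real `get_peft_kwargs`. `infer_lora_flags` and the example keys are hypothetical.

```python
def infer_lora_flags(peft_state_dict_keys):
    """Mirror the key-based detection in the patched get_peft_kwargs."""
    keys = list(peft_state_dict_keys)
    # Module path = everything before the first ".lora" segment
    # (sorted here only to make the output deterministic).
    target_modules = sorted({name.split(".lora")[0] for name in keys})
    use_dora = any("lora_magnitude_vector" in k for k in keys)
    # The line added by the patch: any mixer weight switches the flag on.
    use_mixer = any("lora_AB" in k for k in keys)
    return {"target_modules": target_modules, "use_dora": use_dora,
            "lora_use_mixer": use_mixer}

prefix = "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q"
keys = [f"{prefix}.lora_A.weight", f"{prefix}.lora_B.weight", f"{prefix}.lora_AB.weight"]
flags = infer_lora_flags(keys)
print(flags)
```

The resulting `lora_use_mixer` keyword only works if the installed peft's `LoraConfig` accepts it. The fork bundled under subject_driven_generation/peft presumably does. Stock peft would reject the unknown keyword.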

However, I observed no significant improvement with MoSLoRA over vanilla LoRA. For the prompt 'A_picture_of_a_cube_shaped_ikat_cat', MoSLoRA produced an image that neither reflects the prompt accurately nor preserves the subject's appearance: image

wutaiqiang commented 2 days ago

Hi, glad to see that the bug is fixed.

Here are more details about my results (see https://arxiv.org/pdf/2406.11909v3 for more):

Input images: image

Generated image: image

The left image is more cube-shaped and still a cat.


Based on the image you provided:

image

The LoRA result is a cube but no longer a cat. Also, my MoSLoRA result and yours are similar: both are fat (closer to a cube).

By the way, generation is quite random and evaluation is subjective. In the MoSLoRA paper, we perform a human evaluation of the generated images (details in Section 4.3). We randomly select 8 prompts and generate the corresponding images. Then 15 human experts independently score win/tie/loss for each pair of images from LoRA and MoSLoRA.
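The protocol above yields 8 × 15 = 120 paired judgments. The sketch below shows how such votes can be tallied into win/tie/loss rates. The votes it uses are synthetic placeholders, not the paper's data. `tally_pairwise` is a hypothetical helper, and the paper may aggregate differently.

```python
from collections import Counter

def tally_pairwise(votes):
    """Aggregate win/tie/loss judgments for paired images.

    `votes` is an iterable of (prompt_id, rater_id, outcome) tuples, with
    outcome in {"win", "tie", "loss"} from MoSLoRA's side of the pair.
    """
    counts = Counter(outcome for _, _, outcome in votes)
    unknown = set(counts) - {"win", "tie", "loss"}
    if unknown:
        raise ValueError(f"unexpected outcomes: {unknown}")
    total = sum(counts.values())
    rates = {k: counts[k] / total for k in ("win", "tie", "loss")}
    return counts, rates

# Placeholder votes only (NOT real results): 8 prompts x 15 raters,
# with outcomes cycled so the arithmetic is easy to check.
outcomes = ("win", "tie", "loss")
votes = [(p, r, outcomes[(p + r) % 3]) for p in range(8) for r in range(15)]
counts, rates = tally_pairwise(votes)
print(len(votes), dict(counts), rates)
```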

vitrun commented 2 days ago

Thanks for your reply. Perhaps a larger-scale evaluation is needed to settle the issue. Another question: from the printed logs, it seems that MoSLoRA was only applied to the Linear layers of the Attention module, is that correct? Have you tried implementing MoSLoRA on modules other than Attention, as well as on Conv2D and Embedding layers? I wonder how effective that would be.

wutaiqiang commented 2 days ago

it seems that MoSLoRA was only applied to the Linear layers of the Attention module, is that correct?

MoSLoRA is applied to the linear layers of attention and FFN layers.
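To check this against a saved checkpoint, such as lora-trained-xl-mixer-cat/pytorch_lora_weights.safetensors from the log, one can group the modules that carry a lora_AB weight by layer type. This is a sketch. It assumes diffusers' module naming inside SDXL transformer blocks and the lora_AB key names seen in the warning. The helpers and the FFN example key are hypothetical.

```python
from collections import defaultdict

# Assumed diffusers naming: attention projections under attn1/attn2,
# FeedForward linears at ff.net.0.proj (GELU input projection) and ff.net.2.
ATTN_SUFFIXES = ("to_q", "to_k", "to_v", "to_out.0")
FFN_SUFFIXES = ("ff.net.0.proj", "ff.net.2")

def classify_module(module_name):
    """Label a module path as 'attention', 'ffn', or 'other'."""
    if module_name.endswith(FFN_SUFFIXES):
        return "ffn"
    is_attn_block = ".attn1." in module_name or ".attn2." in module_name
    if is_attn_block and module_name.endswith(ATTN_SUFFIXES):
        return "attention"
    return "other"

def mixer_coverage(state_dict_keys):
    """Group modules that carry a mixer (lora_AB) weight by layer type."""
    groups = defaultdict(set)
    for key in state_dict_keys:
        if ".lora_AB" not in key:
            continue  # lora_A / lora_B keys say nothing about the mixer
        module = key.split(".lora_AB")[0]
        groups[classify_module(module)].add(module)
    return {kind: sorted(mods) for kind, mods in groups.items()}

block = "down_blocks.1.attentions.0.transformer_blocks.0"
keys = [
    f"{block}.attn1.to_q.lora_AB.default_0.weight",
    f"{block}.attn1.to_q.lora_A.default_0.weight",
    f"{block}.ff.net.2.lora_AB.default_0.weight",  # illustrative FFN key
]
print(mixer_coverage(keys))
```

With the safetensors library, the real key list can be read via `safe_open(path, framework="pt")` and its `.keys()` method. A "unet." prefix on saved keys does not affect the substring checks above.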

Have you tried implementing MoSLoRA on modules other than Attention, as well as on Conv2D and Embedding layers?

Not yet, but it should be easy to implement. If you want, you can insert a mixer by modifying these functions: https://github.com/wutaiqiang/MoSLoRA/blob/bba79b5491742df44e740e7348a265a0954da903/subject_driven_generation/peft/tuners/lora/layer.py#L819 and https://github.com/wutaiqiang/MoSLoRA/blob/bba79b5491742df44e740e7348a265a0954da903/subject_driven_generation/peft/tuners/lora/layer.py#L595, following https://github.com/wutaiqiang/MoSLoRA/blob/bba79b5491742df44e740e7348a265a0954da903/subject_driven_generation/peft/tuners/lora/layer.py#L365
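The math of such an extension can be sketched framework-free in NumPy. This is not the repository's code. It assumes the mixer is an r×r map placed between lora_A and lora_B, as the lora_AB key names suggest. For Embedding, it assumes PEFT-style A/B matrices. For Conv2d, it assumes a k×k lora_A conv followed by 1×1 convs for the mixer and lora_B, in which case the whole branch merges into a single k×k kernel. All function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear_mixer_delta(A, M, B, scaling=1.0):
    """Merged update for a Linear layer: scaling * B @ M @ A.

    A: (r, in_features), M: (r, r), B: (out_features, r).
    """
    return scaling * B @ M @ A

def embedding_mixer_forward(ids, A, M, B, scaling=1.0):
    """LoRA branch for an Embedding layer with a mixer.

    A: (r, num_embeddings), B: (embedding_dim, r). Look up rank-r codes,
    mix them with M, then project up with B.
    """
    after_a = A.T[ids]                      # (..., r)
    return scaling * (after_a @ M.T) @ B.T  # (..., embedding_dim)

def conv_mixer_delta(A, M, B, scaling=1.0):
    """Merged kernel for Conv2d: A is a k x k conv kernel (r, in, k, k);
    M (r, r) and B (out, r) are 1x1 convs written as matrices."""
    return scaling * np.einsum("or,rs,sikl->oikl", B, M, A)

# Example shapes (illustrative): rank 4, 8 inputs, 6 outputs.
r, d_in, d_out = 4, 8, 6
A = rng.standard_normal((r, d_in))
M = rng.standard_normal((r, r))
B = rng.standard_normal((d_out, r))
print(linear_mixer_delta(A, M, B).shape)
```

Setting M to the identity recovers vanilla LoRA in all three cases. Expressing the mixer as a 1×1 convolution in the Conv2d case keeps the merge into the base kernel a single einsum.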

vitrun commented 1 day ago

Hello, I read the FLoRA paper and noticed that, specifically for the Linear layer, FLoRA and MoSLoRA exhibit remarkable similarities.

FLoRA: image

MoSLoRA: image

Based on your previous response:

MoSLoRA is applied to the linear layers of attention and FFN layers.

And from the comparison presented in your paper: image

It appears that MoSLoRA, fine-tuning only the Linear layers, outperforms FLoRA, which fine-tunes both Linear and Conv2D layers (please correct me if this assumption is wrong), even though the two approaches are nearly identical for the Linear layer.

Could you provide some insight into how this phenomenon should be interpreted?
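For a Linear layer, both methods can be written in one form. This sketch uses the lora_A / lora_AB / lora_B naming from the MoSLoRA code (A, M, B) and a scaling factor s. Mapping FLoRA's core onto M is an assumption, since FLoRA's paper uses its own notation.

```latex
\begin{aligned}
h &= W_0 x + s\,B M A x, \qquad
  A \in \mathbb{R}^{r \times d_{\mathrm{in}}},\;
  M \in \mathbb{R}^{r \times r},\;
  B \in \mathbb{R}^{d_{\mathrm{out}} \times r} \\
M = I_r &\;\Rightarrow\; \Delta W = s\,B A \quad \text{(vanilla LoRA)} \\
B M A &= B\,(M A) = B A', \qquad A' = M A \in \mathbb{R}^{r \times d_{\mathrm{in}}} \\
\#\text{params} &= r\,(d_{\mathrm{in}} + d_{\mathrm{out}}) + r^2
\end{aligned}
```

The third line shows that M can be folded into A. On a Linear layer the mixer therefore does not enlarge the set of reachable rank-r updates, so differences between the methods there would come from training dynamics and initialization. With rank=4 as in the log, the mixer adds r^2 = 16 parameters per adapted layer.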

wutaiqiang commented 1 day ago

Thanks for pointing this out.

FLoRA and MoSLoRA exhibit remarkable similarities.

Yes, FLoRA and MoSLoRA are concurrent works (posted on arXiv less than a month apart) that both insert a core/mixer into LoRA. We compare them in the latest version and in the camera-ready version; please refer to the related work section of https://arxiv.org/pdf/2406.11909v3.

image image

differences.

The motivations, ideas, and datasets are different, and so is the initialization of the mixer/core. Also, when fine-tuning LLaMA, there are only linear layers and no conv layers.