huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

EMA training for PEFT LoRAs #9998

Open bghira opened 2 days ago

bghira commented 2 days ago

Is your feature request related to a problem? Please describe.

EMAModel in Diffusers is not plumbed for interacting well with PEFT LoRAs, which leaves users to implement their own.

The idea has been floated that LoRA does not benefit from EMA, and that research papers have shown this. However, my curiosity was piqued, and after some effort I managed to make it work.
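For readers less familiar with the term: EMA here means keeping an exponential moving average of the trainable (LoRA) weights alongside the live ones. Below is a minimal sketch in plain Python, using floats in place of tensors. The function name `ema_update`, the parameter names, and the decay value are illustrative assumptions and are not taken from Diffusers or SimpleTuner.

```python
def ema_update(shadow, params, decay=0.999):
    """Blend the current trainable weights into the EMA ("shadow") copy in place."""
    for name, value in params.items():
        # shadow <- decay * shadow + (1 - decay) * current
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value
    return shadow


# Only the adapter weights are tracked; the frozen base weights never change,
# so they need no shadow copy.
lora_params = {"lora_A": 1.0, "lora_B": -2.0}
shadow = dict(lora_params)      # the EMA starts as a copy of the weights
lora_params["lora_A"] = 2.0     # pretend one optimiser step moved this weight
ema_update(shadow, lora_params, decay=0.9)
```

With a decay close to 1 the shadow copy moves slowly, which is what smooths out step-to-step noise in the trained weights.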

Here is a pull request for SimpleTuner where I've updated my EMAModel implementation to behave more like nn.Module and allow EMAModel to be passed into more processes without "funny business".

This spot in the save hooks was hardcoded to take the class name following the Diffusers convention, but we could take a more dynamic approach, perhaps in a training_utils helper method.

A little further down, at L208 in the save hooks, I did something I'm not 100% happy with, although users were:

The tricky part is the 2nd copy of the EMA model that gets saved in the standard LoRA format:

        if self.args.use_ema:
            # we'll temporarily overwrite the LoRA parameters with the EMA parameters so they can be saved.
            logger.info("Saving EMA model to disk.")
            trainable_parameters = [
                p
                for p in self._primary_model().parameters()
                if p.requires_grad
            ]
            self.ema_model.store(trainable_parameters)
            self.ema_model.copy_to(trainable_parameters)
            if self.transformer is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    transformer_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            elif self.unet is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    unet_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            self.ema_model.restore(trainable_parameters)

this could probably be done more nicely with a trainable_parameters() method on the model classes where appropriate.

I guess the state dict conversion wrappers are required for now, but it would be ideal if this could be simplified so that newcomers do not have to dig into and understand so many moving pieces.
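One possible way to tidy the store / copy_to / restore sequence from the save hook is sketched below. It assumes only that the EMA object exposes the `store`, `copy_to` and `restore` methods used in the snippet above; the names `trainable_parameters` and `ema_weights_applied` are hypothetical helpers, not existing Diffusers or SimpleTuner APIs.

```python
from contextlib import contextmanager


def trainable_parameters(model):
    """Return the parameters being trained (for a LoRA, the adapter weights)."""
    return [p for p in model.parameters() if p.requires_grad]


@contextmanager
def ema_weights_applied(ema_model, parameters):
    """Temporarily swap EMA weights into `parameters`, restoring them on exit."""
    parameters = list(parameters)
    ema_model.store(parameters)        # stash the live training weights
    try:
        ema_model.copy_to(parameters)  # overwrite them with the EMA weights
        yield parameters
    finally:
        ema_model.restore(parameters)  # always put the training weights back


# Hypothetical usage inside a save hook:
# with ema_weights_applied(self.ema_model, trainable_parameters(model)):
#     ...save the LoRA weights under os.path.join(output_dir, "ema")...
```

The `try`/`finally` means the training weights are restored even if saving raises, which the inline version does not guarantee. The same wrapper would also fit the validation case, where EMA weights are swapped in and out repeatedly.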

For quantised training, we have to quantise the EMA model in the same way the trained model was quantised.

The validations were kind of a pain, but I wanted the EMA weights to be loadable and unloadable repeatedly during the process, so that each prompt can be validated against both the checkpoint and the EMA weights. Here is my method for enabling (and just below, disabling) the EMA model at inference time.

However, the effect is really nice; here you see the starting SD 3.5M on the left, the trained LoRA in the centre, and EMA on the right.

[5 images: comparison samples]

these samples are from 60,000 steps of training a rank-128 PEFT LoRA on all of the attn layers for the SD 3.5 Medium model on ~120,000 high quality photos.

while it's not a cure-all for training problems, throughout the entire duration of training, the EMA model has outperformed the trained checkpoint.

It'd be a good idea to consider someday including EMA for LoRA with related improvements for saving/loading EMA weights on adapters so that users can receive better results from the training examples. I don't think the validation changes are needed, but they can be done in a non-intrusive way, more nicely than I have done here.

bghira commented 2 days ago

cc @linoytsaban @sayakpaul for your interest perhaps

sayakpaul commented 1 day ago

Thanks for the interesting thread.

I think for now we can refer users to SimpleTuner for this. Also, perhaps it's subjective, but I don't necessarily find the EMA results to be better than the results without it.

bghira commented 1 day ago

yeah the centre's outputs are actually entirely incoherent. don't know why that is preferred