David-Hari opened this issue 1 year ago
It has something to do with the xFormers attention. Training works when I disable it, but then uses more VRAM and takes longer.
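For anyone wanting to A/B-test this, the toggle can be gated behind a command-line flag, roughly like the sketch below. The flag name `--disable_xformers` is hypothetical (the real script's argument may differ), and the commented-out call is the diffusers API for enabling memory-efficient attention on the model being trained.

```python
import argparse

# Hypothetical flag for A/B-testing the suspected nan source;
# the real training script's argument name may differ.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable_xformers", action="store_true",
    help="use standard attention (more VRAM and time, but avoids xFormers)",
)
args = parser.parse_args([])  # empty list here only for illustration

# In the training script, `unet` is the UNet2DConditionModel being trained:
# if not args.disable_xformers:
#     unet.enable_xformers_memory_efficient_attention()
```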
Strangely, I did not get the nan error today. I ran the script twice (not resuming from a checkpoint) and both times it worked, so I'm not sure what's going on there.
I would still like to know how the `save_weights` function is supposed to work, though.
Just to note, I do still get the nan errors sometimes, especially when running for longer. So if anyone else has encountered this problem and has found a solution, I would really like to know.
In the meantime, I will continue to experiment.
I updated the diffusers library to the latest version. I had to install from source (the GitHub main branch) because the 0.11.1 release published on pip does not include the newer `num_cycles` and `power` arguments to the `get_scheduler` function.
That seemed to solve the loss error, though the resulting images aren't that great, so I guess I'll have to play around with the parameters until I find something that works.
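For reference: `num_cycles` applies to the `cosine_with_restarts` schedule, and `power` to the `polynomial` one. Below is a simplified, torch-free sketch of the learning-rate multiplier that the polynomial schedule computes (assuming the decay target `lr_end` is 0; the real diffusers implementation also supports a configurable `lr_end`). In practice this multiplier is what gets plugged into `torch.optim.lr_scheduler.LambdaLR`.

```python
# Simplified sketch of the lr multiplier behind
# get_scheduler("polynomial", ..., power=p), assuming lr_end=0:
# linear warmup, then polynomial decay to zero.
def polynomial_lr_lambda(step, num_warmup_steps=100,
                         num_training_steps=1000, power=1.0):
    if step < num_warmup_steps:
        # linear warmup from 0 to the base learning rate
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    total = max(1, num_training_steps - num_warmup_steps)
    # power=1.0 reduces to linear decay; power>1 decays faster early on
    return max(0.0, remaining / total) ** power
```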
Maybe the nan problem is happening because I generate sample prompts at times during training. I was able to run it for 8000 steps without nan loss when I omitted `save_sample_prompt`.
The question then is: why does saving samples affect the loss? Could it have something to do with `torch.cuda.empty_cache()`? I found that I would occasionally run out of memory unless I called it both before and after using `StableDiffusionPipeline` to generate the samples. Maybe some part of the model being trained gets affected by generating the samples.
By the way, this is the code I am using to generate samples (pretty much the same as the original code in `save_weights`):
```python
torch.cuda.empty_cache()
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
    revision=args.revision,
)
pipeline = pipeline.to(accelerator.device)
g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline.set_progress_bar_config(disable=True)
with torch.autocast('cuda'), torch.inference_mode():
    for i in tqdm(range(len(prompts_list)), desc='Generating samples'):
        sample_image = pipeline(
            prompts_list[i],  # was args.save_sample_prompts, which ignores the loop index
            negative_prompt=args.save_sample_negative_prompt,
            guidance_scale=args.save_guidance_scale,
            num_inference_steps=args.save_infer_steps,
            generator=g_cuda
        ).images[0]
        sample_image.save(os.path.join(args.output_dir, f'{global_step}-{i}.png'))
del pipeline
torch.cuda.empty_cache()
```
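One thing that might be worth ruling out for the "generating samples affects training" theory: the pipeline shares the live `unet`/`text_encoder` modules, so if sampling leaves them in eval mode, dropout-style layers would behave differently afterwards. A sketch (with stand-in modules here, since the real ones are loaded by the script) of restoring training mode explicitly after sampling:

```python
import torch

# Stand-ins for the real UNet and text encoder; only the train/eval
# state matters here, which dropout-style layers depend on.
unet = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.1))
text_encoder = torch.nn.Linear(4, 4)

# The sampling pipeline runs the shared modules in eval/inference mode...
unet.eval()
text_encoder.eval()
# ... sample generation happens here ...

# ...so flip them back before resuming the training loop, in case
# the shared modules were left in eval mode.
unet.train()
text_encoder.train()
```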
Describe the bug
I tried running train_dreambooth.py on some of my own images and a prompt, but each sample image was identical to the initial model's output (so for the prompt "a photo of xzv" I just got random images instead of the "xzv" subject I was training it to recognise). That could have just been because the learning rate was too low, but when I compared the snapshots they were binary-identical, indicating that nothing in the neural network had changed. Could someone explain how `save_weights` works and, assuming it works fine for everyone else, what I am doing wrong?
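A quick way to verify the "binary identical" observation is to hash successive snapshot files; identical digests mean the weights on disk never changed. A sketch (the checkpoint paths in the comment are illustrative, not the script's actual layout):

```python
import hashlib
import pathlib
import tempfile

def file_digest(path):
    """SHA-256 of a file's raw bytes; equal digests mean identical weights."""
    return hashlib.sha256(pathlib.Path(path).read_bytes()).hexdigest()

# Demonstration with temp files; in practice compare e.g.
# <output_dir>/500/unet/diffusion_pytorch_model.bin against the 1000-step one.
with tempfile.TemporaryDirectory() as d:
    a = pathlib.Path(d, "ckpt_a.bin")
    b = pathlib.Path(d, "ckpt_b.bin")
    a.write_bytes(b"\x00" * 16)
    b.write_bytes(b"\x00" * 15 + b"\x01")
    digest_a = file_digest(a)
    digest_b = file_digest(b)
    weights_changed = digest_a != digest_b
```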
I also tried to update the script by merging the more recent changes from the original diffusers train_dreambooth.py. This was because the newest version can do checkpointing, although it does it in a different way to this one. I got it to run, but I'm not sure whether it works: when I generate sample images, they are all completely black.
Here are my changes. I've still got some TODOs in there, so you can see the bits I was not sure would work, but it should at least run.
Perhaps the problem of blank images is due to the loss going to nan at some point during training. I'm not sure what causes that; perhaps the learning rate is too high. When I reduced it to 1e-6 I was able to get at least 100 steps, but eventually it still goes to nan. See the attached log of the `model_pred` and `target` tensors at the line `loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")`.
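Since the log already captures the tensors feeding `F.mse_loss`, a guard like the one sketched below (with stand-in tensors; the real ones come from the training step) can catch the first non-finite batch and skip the optimizer step, instead of letting the nan propagate into the weights:

```python
import torch
import torch.nn.functional as F

# Stand-ins for the real model output and diffusion target.
model_pred = torch.randn(2, 4)
target = torch.randn(2, 4)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# If the loss is nan/inf, report which input went bad and skip the step,
# rather than backpropagating non-finite gradients.
if not torch.isfinite(loss):
    print("non-finite loss;",
          "model_pred finite:", torch.isfinite(model_pred).all().item(),
          "target finite:", torch.isfinite(target).all().item())
    skip_step = True   # in the real loop: continue without optimizer.step()
else:
    skip_step = False  # normal backward()/step() here
```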
Reproduction
Run the script linked in the description, using the following command line arguments:
I used about 20 sample images.
Logs
System Info
`diffusers` version: 0.11.1