huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

A strange thing happened when I wrote my own code to train ControlNet SDXL: as soon as I did the first backpropagation, noise_pred became NaN. #9422

Closed Li-Zn-H closed 2 months ago

Li-Zn-H commented 2 months ago

Describe the bug

A strange thing happened when I wrote my own code to train ControlNet: as soon as I did the first backpropagation, noise_pred became NaN. I did a lot of debugging (gradient clipping, mixed-precision training, removing the EMA, and other parts), but the result was always NaN once backpropagation was applied.
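
A minimal sketch of the kind of check that can localize this, using PyTorch's built-in anomaly detection (controlnet and mse_loss refer to the reproduction below):

import torch

# surfaces the autograd op that first produced a NaN/Inf (slow; debugging only)
torch.autograd.set_detect_anomaly(True)

mse_loss.backward()

# check whether any gradient is already non-finite before the optimizer step
for name, param in controlnet.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"non-finite grad in {name}")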

Reproduction

My model and dataset setup:

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch_ema import ExponentialMovingAverage  # the torch_ema package is assumed for the EMA helper
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DDIMScheduler, UNet2DConditionModel

tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained(sd_path, subfolder='tokenizer_2')
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder='text_encoder', torch_dtype=torch.float16).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(sd_path, subfolder='text_encoder_2', torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained("/data2/lixq22/models/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder='unet', torch_dtype=torch.float16).to(device)
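# note: the trainable ControlNet below is loaded in float16; per the resolution at the
# end of the thread, training float16 weights directly is what produced the NaNs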
controlnet = ControlNetModel.from_pretrained(cn_path, torch_dtype=torch.float16).to(device)
scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
unet.requires_grad_(False)
vae.requires_grad_(False)
controlnet.train()

controlnet = DDP(controlnet, device_ids=[rank])
optimizer = torch.optim.AdamW(controlnet.parameters(), lr= 1e-5, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)
dataset = MyDataset()

train_sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

train_loader = DataLoader(
    dataset, 
    batch_size=2, 
    collate_fn=collate_fn, 
    sampler=train_sampler, 
    num_workers=1, 
    pin_memory=True
)
lr_scheduler = OneCycleLR(optimizer, max_lr=1e-4, total_steps=total_steps, pct_start=0.1, anneal_strategy='cos')
ema = ExponentialMovingAverage(controlnet.parameters(), decay=0.995)

# my forward and backpropagation code
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    controlnet.train()
    optimizer.zero_grad(set_to_none=True)  # the original passed set_to_none="store_true", an apparent slip
    epoch_loss = 0
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]
    for i, data in enumerate(train_loader):
        data = {k: (v.to(device).to(torch.float16) if isinstance(v, torch.Tensor) else v) for k, v in data.items()}
        with torch.no_grad():
            prompt_embeds_list = []
            for prompt, tokenizer, text_encoder in zip([data['prompts'], data['prompts']], tokenizers, text_encoders):
                text_input_ids = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids

                prompt_embeds = text_encoder(
                    text_input_ids.to(device),
                    output_hidden_states=True)
                pooled_prompt_embeds = prompt_embeds[0]  # [b,1280]; after the loop this holds text_encoder_2's pooled output
                prompt_embeds = prompt_embeds.hidden_states[-2]  # penultimate hidden state, [b,77,768] / [b,77,1280]
                prompt_embeds_list.append(prompt_embeds)
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)  #  [b,77,2048]

            add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))  # original_size + crops_coords_top_left + target_size
            add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16).to(device).repeat(len(data['prompts']), 1)  # [b,6]

            latents = vae.encode(data['pixel_values']).latent_dist.sample() * 0.18215   # [b,4,128,128]; note: 0.18215 is the SD 1.x scaling factor, SDXL's vae.config.scaling_factor is 0.13025
            controlnet_image = data['conditioning_pixel_values']   # [b,3,1024,1024]
            bsz = len(latents)
            timesteps = torch.randint(0, 1000, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)  # [b,4,128,128]
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)  # [b,4,128,128]

        down_block_res_samples, mid_block_res_sample = controlnet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            controlnet_cond=controlnet_image,
            return_dict=False,
        )
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=prompt_embeds,
            down_block_additional_residuals=list(down_block_res_samples),
            mid_block_additional_residual=mid_block_res_sample,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]

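        # the loss computation was omitted from the original snippet; a standard
        # epsilon-prediction MSE against the sampled noise is assumed here
        # (with gradient accumulation, dividing by accumulation_steps is typical)
        mse_loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")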
        mse_loss.backward()

        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(controlnet.parameters(), max_norm=1.0)
            optimizer.step()
            ema.update(controlnet.parameters())
            lr_scheduler.step()
            optimizer.zero_grad()
            if rank == 0:
                current_lr = lr_scheduler.get_last_lr()[0]
                writer.add_scalar('Loss/train', mse_loss.item(), global_step)  # the original referenced an undefined weighted_loss; the MSE loss is assumed
                writer.add_scalar('Learning Rate', current_lr, global_step)
                writer.flush()
            global_step += 1
            ema.copy_to(controlnet.parameters())  # note: copying EMA weights into the live model every optimizer step is unusual; copy_to is typically used only when evaluating or saving
        epoch_loss += mse_loss.item()

Logs

No response

System Info

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Who can help?

No response

Li-Zn-H commented 2 months ago

At the same time, I ran the experiment again. After backpropagation, prompt_embeds, add_time_ids, and pooled_prompt_embeds are all still normal (no NaNs).

Li-Zn-H commented 2 months ago

The first occurrence of NaN is during the ControlNet forward pass, right after the first backpropagation:

down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
    controlnet_cond=controlnet_image,
    return_dict=False,
)
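
One way to pinpoint which submodule emits the first NaN is a forward hook on every module; a minimal sketch (install_nan_hooks is a hypothetical helper, not part of the original code):

import torch

def install_nan_hooks(model):
    # print each module whose output contains a NaN/Inf
    def make_hook(name):
        def hook(module, inputs, output):
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if isinstance(t, torch.Tensor) and not torch.isfinite(t).all():
                    print(f"non-finite output in {name} ({module.__class__.__name__})")
                    break
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

install_nan_hooks(controlnet)  # then rerun the forward pass that produces the NaN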

asomoza commented 2 months ago

cc: @sayakpaul

sayakpaul commented 2 months ago

This should likely be a discussion and not an issue, because folks were able to train successfully on the example dataset.

There could be many reasons for this kind of behaviour, but the first thing I would try is overfitting a single batch of data.
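
A minimal sketch of that suggestion, reusing the names from the reproduction (training_step is a hypothetical wrapper around the forward/loss code above):

batch = next(iter(train_loader))  # freeze a single batch

for step in range(200):
    loss = training_step(batch)   # hypothetical: forward pass + MSE loss on the same batch every step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(step, loss.item())      # the loss should drop steadily if the pipeline is numerically sound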

Li-Zn-H commented 2 months ago

This should likely be a discussion and not an issue, because folks were able to train successfully on the example dataset.

There could be many reasons for this kind of behaviour, but the first thing I would try is overfitting a single batch of data.

I'm sorry, I submitted it in the wrong place; this is also the first time I've had a problem that stayed unsolved for two days. Just now I tried it again, and I found that everything works when I specify float32 for all my models and variables. Before, I had been using float16 (or mixed-precision training), and every one of those runs hit the bug I described (well, bug 🤦).
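
For anyone who lands here: the usual pattern (and, to the best of my knowledge, what the official train_controlnet_sdxl.py script does) is to cast only the frozen models to float16 and keep the trainable ControlNet in float32. A minimal sketch under that assumption, with the variable names from the reproduction:

weight_dtype = torch.float16

# frozen models can safely run in half precision
vae.to(device, dtype=weight_dtype)
unet.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)
text_encoder_2.to(device, dtype=weight_dtype)

# the trainable model keeps float32 master weights for the optimizer
controlnet = ControlNetModel.from_pretrained(cn_path)  # no torch_dtype=torch.float16
controlnet.to(device)

# optionally run the forward pass under autocast and scale the loss,
# so activations are fp16 while gradients and updates stay numerically safe
scaler = torch.cuda.amp.GradScaler()
with torch.autocast("cuda", dtype=weight_dtype):
    ...  # forward pass and MSE loss as in the reproduction
# scaler.scale(mse_loss).backward(); scaler.step(optimizer); scaler.update()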