huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Textual Inversion Broken: it updates entire `embeddings` weights, loss diverges #507

Closed freckletonj closed 1 year ago

freckletonj commented 2 years ago

Describe the bug

The weights of the entire text_encoder evolve over the course of training, which corrupts the text_encoder and, in turn, breaks Textual Inversion. I'm not sure why yet.

To demonstrate it,

1.) Save a random token id and its embedding, outside the main loop:

    token_embed_w_copy = text_encoder.get_input_embeddings().weight.data.clone().detach().requires_grad_(False).to(accelerator.device)

    # never seen
    test_tok_id = tokenizer.convert_tokens_to_ids('alligator')
    test_tok = token_embed_w_copy[test_tok_id]

2.) Inside the loop, assert that it's not changing:

    test_tok_actual = text_encoder.get_input_embeddings().weight.data[test_tok_id]
    assert torch.allclose(test_tok, test_tok_actual)  # BREAKS!

The assertion passes until an entire batch completes, at which time the embeddings diverge.

The code currently tries to solve this by zeroing out all of the non-placeholder_token gradients, but this (or something else) fails to keep the weights from updating.
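
For reference, the zeroing in question looks roughly like this (a minimal sketch; `placeholder_token_id` and the other names are assumptions, not the exact script):

    # Sketch: zero the gradient of every embedding row except the placeholder's
    grads = text_encoder.get_input_embeddings().weight.grad
    index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
    grads.data[index_grads_to_zero, :] = 0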

I've confirmed that this weight drift is what breaks TI: manually copying the entire set of non-placeholder weights back after every batch fixes TI (roughly as sketched below). But it's duct tape, really, and I'm hoping someone has a better idea.
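
A minimal sketch of that duct-tape workaround, assuming `token_embed_w_copy` from step 1 and the same `index_grads_to_zero` mask as above:

    # Sketch: after optimizer.step(), restore the original non-placeholder rows
    with torch.no_grad():
        weights = text_encoder.get_input_embeddings().weight
        weights.data[index_grads_to_zero, :] = token_embed_w_copy[index_grads_to_zero, :]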

EDIT: this does not actually solve it. It helps a little, it seems, but the loss still random-walks / diverges. I can even zero out all the gradients at each step and it still behaves strangely.

System Info

Debian, Python 3.9.2, revision b2b3b1a8ab83b020ecaf32f45de3ef23644331cf

freckletonj commented 2 years ago

I closed this because I mistakenly thought there wasn't actually a bug. Reopening because I still think there is.

JunnYu commented 2 years ago

This is caused by weight_decay; when weight_decay=0, the behavior is correct.

Here is a small demo:

import torch
import torch.nn as nn
import numpy
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

torch.manual_seed(100)
numpy.random.seed(100)

class SimpleNet(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.linear(x)
        return x

mse = nn.MSELoss()
epochs = 2
input_size = 8192  
output_size = 8192  
batch_size = 2048   
nums_batch = 10

class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        data = numpy.random.random([input_size]).astype('float32')
        label = numpy.random.random([output_size]).astype('float32')
        return torch.from_numpy(data).cuda(), torch.from_numpy(label).cuda()

    def __len__(self):
        return self.num_samples

dataset = RandomDataset(nums_batch * batch_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)

model = SimpleNet(input_size, output_size)

# Toggle this to compare the two cases
# WEIGHT_DECAY = 0
WEIGHT_DECAY = 0.01
optimizer = AdamW(lr=0.0001, params=model.parameters(), weight_decay=WEIGHT_DECAY) 

model.train()
model.cuda()

# Snapshot of the initial weights, used below to measure drift per row
clone_weight = model.linear.weight.data.clone()

for epoch in range(epochs):
    for i, (data, label) in enumerate(loader):

        output = model(data)
        loss = mse(output, label)
        loss.backward()

        # Zero the gradients of every row except the last one,
        # mimicking how textual inversion masks the non-placeholder embeddings
        grads = model.linear.weight.grad
        index_grads_to_zero = (torch.arange(output_size) != output_size - 1).cuda()
        grads.data[index_grads_to_zero, :] = 0

        with torch.no_grad():
            diff_zero = (clone_weight[:-1] - model.linear.weight.data[:-1]).abs().sum()
            diff_nonzero = (clone_weight[-1] - model.linear.weight.data[-1]).abs().sum()
        print("before optimizer", diff_zero.item(), diff_nonzero.item())
        optimizer.step()
        with torch.no_grad():
            diff_zero = (clone_weight[:-1] - model.linear.weight.data[:-1]).abs().sum()
            diff_nonzero = (clone_weight[-1] - model.linear.weight.data[-1]).abs().sum()
        print("after optimizer", diff_zero.item(), diff_nonzero.item())
        optimizer.zero_grad()
        print("="*50)

    break

WEIGHT_DECAY = 0.01: wrong, the zeroed rows drift as well.

WEIGHT_DECAY = 0: right, only the last row is updated.

patrickvonplaten commented 2 years ago

@patil-suraj could you take a look here?

freckletonj commented 2 years ago

I've checked, and this problem does not seem to exist in https://github.com/nicolai256/Stable-textual-inversion_win.

patil-suraj commented 2 years ago

That's a great catch @freckletonj and @JunnYu, taking a look.

patil-suraj commented 2 years ago

You are right @JunnYu! The weight_decay indeed updates the whole embeddings. Will send a fix soon.
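
For context, AdamW's decoupled weight decay multiplies every parameter by (1 - lr * weight_decay) on each step, independently of the gradient, so embedding rows whose gradients were zeroed still drift. A minimal standalone illustration (not the training script itself):

    import torch
    from torch.optim import AdamW

    w = torch.nn.Parameter(torch.ones(4, 2))
    opt = AdamW([w], lr=0.1, weight_decay=0.01)

    w.grad = torch.zeros_like(w)  # pretend every gradient row was zeroed out
    opt.step()
    print(w.data)  # every entry is now 0.999: decayed despite a zero gradient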

isamu-isozaki commented 2 years ago

Hi. I was able to do this by something like

with torch.no_grad():
    embeddings = text_encoder.get_input_embeddings().weight
    decay = lr_scheduler.get_last_lr()[0] * args.adam_weight_decay
    embeddings[~grad_mask, :] -= decay * embeddings[~grad_mask, :]

where grad_mask is the mask for the gradient.

isamu-isozaki commented 2 years ago

Actually, the above is wrong with accelerate, since it'll do that update multiple times. But it should work if we do:

if accelerator.sync_gradients:
    with torch.no_grad():
        embeddings = text_encoder.get_input_embeddings().weight
        decay = lr_scheduler.get_last_lr()[0] * args.adam_weight_decay
        embeddings[~grad_mask, :] -= decay * embeddings[~grad_mask, :]

although it's not the full update AdamW was supposed to do.

patil-suraj commented 2 years ago

Working on the fix, it should be ready by the end of the week. Sorry to get back to this only now!

patrickvonplaten commented 2 years ago

re-pinging @patil-suraj here :-)

patrickvonplaten commented 2 years ago

Another ping @patil-suraj

patrickvonplaten commented 1 year ago

I found some time to push a fix: https://github.com/huggingface/diffusers/pull/1665#issue-1492039634

@JunnYu, feel free to give it a look.