I closed this because I mistakenly thought there wasn't actually a bug. Reopening because I still think there is.
This is caused by weight_decay; when weight_decay=0, the behavior is correct. Here is a small demo:
import torch
import torch.nn as nn
import numpy
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

torch.manual_seed(100)
numpy.random.seed(100)


class SimpleNet(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.linear(x)
        return x


mse = nn.MSELoss()
epochs = 2
input_size = 8192
output_size = 8192
batch_size = 2048
nums_batch = 10


class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        data = numpy.random.random([input_size]).astype('float32')
        label = numpy.random.random([output_size]).astype('float32')
        return torch.from_numpy(data).cuda(), torch.from_numpy(label).cuda()

    def __len__(self):
        return self.num_samples


dataset = RandomDataset(nums_batch * batch_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)

model = SimpleNet(input_size, output_size)

# Toggle the weight decay here: with WEIGHT_DECAY = 0 the rows whose gradients
# are zeroed below stay untouched; with 0.01 they drift on every step.
# WEIGHT_DECAY = 0
WEIGHT_DECAY = 0.01

optimizer = AdamW(lr=0.0001, params=model.parameters(), weight_decay=WEIGHT_DECAY)
model.train()
model.cuda()

clone_weight = model.linear.weight.data.clone()

for epoch in range(epochs):
    for i, (data, label) in enumerate(loader):
        output = model(data)
        loss = mse(output, label)
        loss.backward()

        # Zero the gradients of every row except the last one (index 8191).
        grads = model.linear.weight.grad
        index_grads_to_zero = (torch.arange(8192) != 8191).cuda()
        grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

        with torch.no_grad():
            diff_zero = (clone_weight[:-1] - model.linear.weight.data[:-1]).abs().sum()
            diff_nonzero = (clone_weight[-1] - model.linear.weight.data[-1]).abs().sum()
            print("before optimizer", diff_zero.item(), diff_nonzero.item())

        optimizer.step()

        with torch.no_grad():
            diff_zero = (clone_weight[:-1] - model.linear.weight.data[:-1]).abs().sum()
            diff_nonzero = (clone_weight[-1] - model.linear.weight.data[-1]).abs().sum()
            print("after optimizer", diff_zero.item(), diff_nonzero.item())

        optimizer.zero_grad()
        print("=" * 50)
        break
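For context: AdamW applies its weight decay directly to the parameters, decoupled from the gradient, so even rows whose gradients have been zeroed are still shrunk on every optimizer.step(). A minimal standalone sketch (not part of the demo above) that shows the effect:

import torch
from torch.optim import AdamW

# A single parameter whose gradient is explicitly all zeros.
p = torch.nn.Parameter(torch.ones(4))
opt = AdamW([p], lr=0.1, weight_decay=0.01)

p.grad = torch.zeros_like(p)
opt.step()

# The Adam part of the update is zero, but the decoupled decay still ran:
print(p.data)  # tensor([0.9990, 0.9990, 0.9990, 0.9990]) instead of all ones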
@patil-suraj could you take a look here?
I've confirmed this problem does not seem to exist in: https://github.com/nicolai256/Stable-textual-inversion_win
That's a great catch @freckletonj and @JunnYu, taking a look.
You are right @JunnYu! The weight_decay indeed updates the entire embedding matrix. Will send a fix soon.
Hi. I was able to work around this with something like:
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[~grad_mask, :] -= (
        lr_scheduler.get_last_lr()[0]
        * args.adam_weight_decay
        * text_encoder.get_input_embeddings().weight[~grad_mask, :]
    )
where grad_mask is the boolean mask over the embedding rows whose gradients are kept, so ~grad_mask selects the frozen rows.
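For clarity, grad_mask might be built roughly like this; a sketch only, where tokenizer and placeholder_token_id are assumed to come from the training script:

# Sketch: a row mask over the embedding matrix, True only at the placeholder token.
grad_mask = torch.zeros(len(tokenizer), dtype=torch.bool)
grad_mask[placeholder_token_id] = True
grad_mask = grad_mask.to(text_encoder.device)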
Actually, the above is wrong with accelerate, since it will apply that update multiple times per optimizer step. But it should work if we do:
if accelerator.sync_gradients:
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[~grad_mask, :] -= (
            lr_scheduler.get_last_lr()[0]
            * args.adam_weight_decay
            * text_encoder.get_input_embeddings().weight[~grad_mask, :]
        )
although it's not the full update that AdamW was supposed to apply.
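Another option, rather than re-deriving the decay term, is to snapshot the embedding matrix once and copy the frozen rows back after each synced step, which restores them exactly. A small self-contained sketch of that idea (a toy embedding stands in for the text encoder; this is not necessarily what the eventual fix does):

import torch
from torch import nn
from torch.optim import AdamW

# Toy stand-in for the text encoder's input embeddings: 10 tokens, dim 4.
embeddings = nn.Embedding(10, 4)
orig = embeddings.weight.data.clone()

# Only token id 9 (the "placeholder") is allowed to change.
grad_mask = torch.zeros(10, dtype=torch.bool)
grad_mask[9] = True

optimizer = AdamW(embeddings.parameters(), lr=1e-3, weight_decay=0.01)

loss = embeddings(torch.tensor([9, 3])).sum()
loss.backward()
embeddings.weight.grad[~grad_mask] = 0  # zero the frozen rows' gradients
optimizer.step()                        # weight decay still touched every row

# Copy the frozen rows back, undoing both the decay and any drift.
with torch.no_grad():
    embeddings.weight[~grad_mask] = orig[~grad_mask]

assert torch.equal(embeddings.weight.data[~grad_mask], orig[~grad_mask])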
Working on the fix, it should be ready by the end of the week. Sorry to get back to this only now!
re-pinging @patil-suraj here :-)
Another ping @patil-suraj
I found some time to push a fix: https://github.com/huggingface/diffusers/pull/1665#issue-1492039634
@JunnYu , feel free to give it a look.
Describe the bug
The weights of the entire text_encoder evolve over the course of training, which corrupts the text_encoder. I'm not sure why yet, but this in turn breaks Textual Inversion.

To demonstrate it:

1.) Save a random token id and its embedding outside the main loop.
2.) Inside the loop, assert that it's not changing (a sketch of such a check is below).

The assertion passes until an entire batch completes, at which time the embeddings diverge.
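A hedged sketch of the kind of check described above; the original snippet isn't reproduced in this thread, and text_encoder is the CLIP text encoder from the textual-inversion training script:

import torch

# Pick any token id that is NOT the placeholder token and snapshot its embedding.
frozen_id = 42  # arbitrary choice for illustration
frozen_embed = text_encoder.get_input_embeddings().weight[frozen_id].detach().clone()

# Inside the training loop, after each optimizer step:
assert torch.allclose(
    text_encoder.get_input_embeddings().weight[frozen_id], frozen_embed
), "embedding of a non-placeholder token changed!"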
The code currently tries to solve this by zeroing all the non-placeholder_token gradients, but this (or something else) fails to keep the weights from updating.

I've confirmed that this breaks TI by manually copying back the entire set of non-placeholder weights after every batch, and that fixes TI. But it's duct tape, really, and I'm hoping someone has a better idea.

EDIT: this does not actually solve it. It helps a little, it seems, but the loss still random-walks / diverges. I can even zero out all the gradients each step and it still behaves strangely.
System Info
Debian, Python 3.9.2, revision b2b3b1a8ab83b020ecaf32f45de3ef23644331cf