Closed: pedrocolon93 closed this issue 3 years ago
The code in question is the following:
g_opt, d_opt = self.optimizers()
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
real_label = torch.ones((len(train_batch["text"]), 1)).to(self.discriminator.device)
fake_label = torch.zeros((len(train_batch["text"]), 1)).to(self.discriminator.device)
only_text = self.generator_tok(train_batch["text"],
                               max_length=self.max_source_length,
                               padding="max_length",
                               truncation=True,
                               return_tensors="pt")
if "mem" in train_batch:
    update_mem = []
    with self.generator_tok.as_target_tokenizer():
        rework = np.array(train_batch["mem"])
        try:
            rework = rework.squeeze(1)
        except ValueError:  # axis 1 is not a singleton dimension; keep the array as-is
            pass
        rework = rework.transpose()
        for mem_set in rework:
            dec_enc_relevant_rels = self.generator_tok(mem_set.tolist(), max_length=128,
                                                       padding="max_length", truncation=True)
            update_mem.append(torch.tensor(dec_enc_relevant_rels["input_ids"], device=self.generator.device))
    update_mem = torch.stack(update_mem, 0)
else:
    update_mem = None
if batch_idx > self.discriminator_freeze_iters:
    self.discriminator.zero_grad()
    # Format batch
    real_batch = []
    for i in range(len(train_batch["text"])):
        # Condition on the input text and the actual relation
        real_batch.append(train_batch["text"][i] + self.discriminator.tokenizer.sep_token + train_batch["relation"][i])
    real_cpu = self.discriminator_tok(real_batch,
                                      max_length=self.max_target_length + self.max_source_length + 1,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
    for i in real_cpu:
        real_cpu[i] = real_cpu[i].to(self.discriminator.device)
    # Forward pass real batch through D
    output = self.discriminator(real_cpu)
    # Calculate loss on all-real batch
    errD_real = F.binary_cross_entropy_with_logits(output, real_label)
    # Calculate gradients for D in backward pass
    # self.manual_backward(errD_real, d_opt)
    errD_real.backward()
    D_x = output.mean().item()

    ## Train with all-fake batch
    # Generate a fake (generated-text) batch with G
    for i in only_text:
        only_text[i] = only_text[i].to(self.generator.device)
    fake = self.generator.generate(
        input_ids=only_text['input_ids'],
        attention_mask=only_text["attention_mask"],
        max_length=self.max_target_length,
        num_beams=1,
        do_sample=False,
        update_mem=update_mem,
        # use_mem=True,
        # clear_mem=True,
    )
    fake_texts = []
    for i in range(fake.shape[0]):
        gids = fake[i, :].tolist()
        s = self.generator_tok.decode(gids, skip_special_tokens=False, clean_up_tokenization_spaces=True)  # Replace end of sentence stuff
        fake_texts.append(s)
    fake_batch = []
    for i in range(len(train_batch["text"])):
        # Condition on the input text and the generated relation
        fake_batch.append(train_batch["text"][i] + self.discriminator.tokenizer.sep_token + fake_texts[i])
    fake_cpu = self.discriminator_tok(fake_batch,
                                      max_length=self.max_target_length + self.max_source_length + 1,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
    for i in fake_cpu:
        fake_cpu[i] = fake_cpu[i].to(self.discriminator.device)
    # Classify all fake batch with D
    output = self.discriminator(fake_cpu)  # Should be detached by here.
    # Calculate D's loss on the all-fake batch
    errD_fake = F.binary_cross_entropy_with_logits(output, fake_label)
    # Calculate the gradients for this batch, accumulated (summed) with previous gradients
    errD_fake.backward()
    # self.manual_backward(errD_fake, d_opt)
    D_G_z1 = output.mean().item()

    # Build a "confusion" batch by shuffling subjects/objects across the real relations
    items = []
    for i in range(len(train_batch["text"])):
        # Condition on the input text and the actual relation
        items.append(train_batch["text"][i] + self.discriminator.tokenizer.sep_token + train_batch["relation"][i])
    amount_of_falses = len(items)
    # e.g. <subj> dust <obj> the refrigerator </relation>
    subjects = [item[item.rfind("<subj>") + 6:item.rfind("<obj>")] for item in items]
    objects = [item[item.rfind("<obj>") + 5:item.rfind("</relation>")] for item in items]
    random.shuffle(subjects)
    random.shuffle(objects)
    falses = []
    n, p = 1, .5  # n = coins flipped, p = prob of success
    for i in range(amount_of_falses):
        flip = np.random.binomial(n, p)
        samp = random.sample(items, 1)[0]
        if flip == 1:  # swap in a random subject
            subj = random.sample(subjects, 1)[0]
            samp = samp[:samp.rfind("<subj>") + 6] + subj + samp[samp.rfind("<obj>"):]
        else:  # swap in a random object
            obj = random.sample(objects, 1)[0]
            samp = samp[:samp.rfind("<obj>") + 5] + obj + samp[samp.rfind("</relation>"):]
        falses.append(samp)
    all_items = items + falses
    conf_y = [[1] for i in range(len(items))] + [[0] for i in range(len(falses))]
    conf_y = torch.tensor(conf_y, device=self.discriminator.device)
    real_cpu = self.discriminator_tok(all_items,
                                      max_length=self.max_target_length + self.max_source_length + 1,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
    for i in real_cpu:
        real_cpu[i] = real_cpu[i].to(self.discriminator.device)
    # Forward pass the real + shuffled batch through D
    output = self.discriminator(real_cpu)
    # Calculate loss on the confusion batch
    errD_conf = F.binary_cross_entropy_with_logits(output, conf_y.float())
    # Calculate gradients for D in backward pass
    # self.manual_backward(errD_conf, d_opt)
    errD_conf.backward()
    # Compute error of D as the sum over the real, fake, and confusion batches
    errD = errD_real + errD_fake + errD_conf
    # Update D
    d_opt.step()
    # self.manual_optimizer_step(d_opt, force_optimizer_step=True, custom_args=4)
else:
    errD = torch.tensor([0])
    errD_real = torch.tensor([0])
    errD_fake = torch.tensor([0])
    errD_conf = torch.tensor([0])
    print("Skipping d update")
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
self.generator.zero_grad()
# Format the targets for loss calc
padding = "max_length"
targets = train_batch["relation"]
# Setup the tokenizer for targets
with self.generator_tok.as_target_tokenizer():
    labels = self.generator_tok(targets, max_length=self.max_target_length,
                                padding=padding, truncation=True)
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
# padding in the loss.
if padding == "max_length" and ignore_pad_token_for_loss:
    labels["input_ids"] = [
        [(l if l != self.generator_tok.pad_token_id else -100) for l in label] for label in labels["input_ids"]
    ]
only_text["labels"] = torch.tensor(labels["input_ids"])
for i in only_text:
    only_text[i] = only_text[i].to(self.generator.device)
# Normal LM loss for the gen
lm_loss = self.generator(**only_text, update_mem=update_mem,
                         use_mem=True,
                         clear_mem=True)
if batch_idx > self.discriminator_freeze_iters:
    # fake = self.generator.generate(
    #     input_ids=only_text['input_ids'],
    #     attention_mask=only_text["attention_mask"],
    #     max_length=self.max_target_length,
    #     do_sample=False
    # )
    logits_processor = LogitsProcessorList([
        MinLengthLogitsProcessor(15, eos_token_id=self.generator.config.eos_token_id),
    ])
    encoder_outputs = self.generator.model.encoder(only_text['input_ids'], return_dict=True, output_hidden_states=True)
    decoder_input_ids = torch.tensor([[self.generator.config.decoder_start_token_id] for x in range(len(train_batch["text"]))]).to(self.generator.device)
    fake = self.generator.greedy_search(decoder_input_ids,
                                        encoder_outputs=encoder_outputs,
                                        logits_processor=logits_processor,
                                        max_length=self.max_target_length,
                                        output_scores=True,
                                        return_dict_in_generate=True,
                                        update_mem=update_mem
                                        )
    vectors = []
    # Need to do this on the CPU for memory.
    for s in fake["scores"]:
        number_of_gpus = 16
        sms = nn.functional.softmax(torch.clamp(self.scale, 0, 1000) * s, dim=-1)
        C_split = torch.split(self.discriminator.encoder.shared.weight,
                              self.discriminator.encoder.shared.weight.shape[1] // number_of_gpus, dim=1)
        # Loop over the chunks and perform the calculation on each using the corresponding slice of the embedding matrix
        D_split = []
        for i in range(number_of_gpus):
            # device = 'cuda:{:d}'.format(i)
            D_split.append(sms @ C_split[i])
        # DO THIS ONLY IF YOU HAVE ENOUGH CPU MEMORY!! :
        embeddings = torch.cat([d.cpu() for d in D_split], dim=1)
        # embeddings = torch.matmul(sms.cpu(), self.discriminator.encoder.shared.weight.cpu()*self.discriminator.encoder.encoder.embed_scale)
        # embeddings = torch.bmm(sms, self.discriminator.encoder.shared.weight*self.discriminator.encoder.encoder.embed_scale)
        vectors.append(embeddings)
    fake_vectors = torch.stack(vectors, 1).to(self.discriminator.device)  # We have the 127 vectors that correspond to the generated output;
    # now we need to join them with the vectors that correspond to the conditioned text and the separator.
    self.discriminator.encoder.shared.weight.to(self.discriminator.device)  # NOTE: .to() is not in-place, so this line has no effect as written.
    bos_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.bos_token_id]).to(self.discriminator.device)) * self.discriminator.encoder.encoder.embed_scale
    eos_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.eos_token_id]).to(self.discriminator.device)) * self.discriminator.encoder.encoder.embed_scale
    sep_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.sep_token_id]).to(self.discriminator.device)) * self.discriminator.encoder.encoder.embed_scale
    pad_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.pad_token_id]).to(self.discriminator.device)) * self.discriminator.encoder.encoder.embed_scale
    fake_texts = []
    for i in range(fake_vectors.shape[0]):
        # print(fake_vectors[i,:,:].shape, i, fake_vectors.shape[0])
        gids = fake["sequences"][i, :].tolist()
        to_remove = [j for j in range(0, len(gids)) if gids[j] == self.discriminator_tok.pad_token_id] + \
                    [j for j in range(0, len(gids)) if gids[j] == self.discriminator_tok.bos_token_id] + \
                    [j for j in range(0, len(gids)) if gids[j] == self.discriminator_tok.eos_token_id] + \
                    [j for j in range(0, len(gids)) if gids[j] == self.discriminator_tok.sep_token_id]
        to_remove = sorted(list(set(to_remove)))
        to_keep = [j for j in range(0, len(gids) - 1) if j not in to_remove]
        # print('TO KEEP', to_keep)
        # print('NOT TO KEEP', to_remove)
        # print(fake_vectors[i, to_keep, :].shape)
        fake_texts.append(fake_vectors[i, to_keep, :])
    fake_batch = []
    for i in range(len(train_batch["text"])):
        # Condition on the input text (embedded with the discriminator's shared embeddings)
        fake_batch.append(self.discriminator.encoder.shared(
            self.discriminator_tok(train_batch["text"][i],
                                   add_special_tokens=False,
                                   truncation=True,
                                   max_length=self.max_source_length,
                                   return_tensors="pt")["input_ids"].to(self.discriminator.device)
        ) * self.discriminator.encoder.encoder.embed_scale)
    final_fake_batch = {
        "inputs_embeds": [],
        "attention_mask": []
    }
    for i in range(len(train_batch["text"])):
        vecs = [bos_vec, fake_batch[i].squeeze(0), sep_vec, fake_texts[i], eos_vec]
        vecs = torch.cat(vecs, 0)
        tot = vecs.shape[0]
        att = [1 for _ in range(0, tot)] + [0 for _ in range(0, self.max_target_length + self.max_source_length + 1 - tot)]
        vecs = torch.cat([vecs] + [pad_vec for _ in range(0, self.max_target_length + self.max_source_length + 1 - tot)], 0)
        att = torch.tensor(att).to(self.discriminator.device)
        final_fake_batch["inputs_embeds"].append(vecs)
        final_fake_batch["attention_mask"].append(att)
    for i in final_fake_batch:
        final_fake_batch[i] = torch.stack(final_fake_batch[i], 0)
        final_fake_batch[i] = final_fake_batch[i].to(self.discriminator.device)
    # Since we just updated D, perform another forward pass of the all-fake batch through D
    output = self.discriminator(final_fake_batch)
    # Calculate G's loss based on this output
    critic_loss = F.binary_cross_entropy_with_logits(output, real_label)
    errG = critic_loss + lm_loss[0]
else:
    critic_loss = torch.tensor([0])
    errG = lm_loss[0]
# Calculate gradients for G
errG.backward()
# self.manual_backward(errG, g_opt)
# Update G
g_opt.step()
In particular, whenever I switch from backward(...) to self.manual_backward(error, opt), it is considerably slower (even for FP16 calculations).
Are you using multiple GPUs? If so, self.manual_backward takes care of gradient synchronisation, which adds overhead.
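For intuition, that synchronisation is essentially an all-reduce of every gradient across the participating processes on each backward call. Here is a minimal sketch of what it amounts to, using plain torch.distributed rather than Lightning internals (the helper name is illustrative, and it assumes the process group is already initialised):
import torch
import torch.distributed as dist

def naive_grad_sync(model: torch.nn.Module, world_size: int) -> None:
    # Roughly what DDP's reducer does (bucketed and overlapped with backward
    # in practice): average every parameter's gradient across all processes.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size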
Ah! I am! Now I am confused: if I don't use self.manual_backward, does that sync not happen? What would that cause?
And as a follow-up, what would you recommend for training this? I have 1 machine with 4 GPUs.
For anyone having this issue, I recommend reading through these: https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html and https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html. This sped things up a lot.
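As a rough illustration of the multi-GPU setup those guides describe, here is a minimal Trainer sketch. It assumes a 2021-era PyTorch Lightning release where the gpus/accelerator/precision arguments below are valid (newer releases spell them devices and strategy), and model/dm stand in for your own module and datamodule:
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=4,              # one machine, four GPUs
    accelerator="ddp",   # one process per GPU; gradients are all-reduced on backward
    precision=16,        # mixed precision, one of the speed-ups covered in the guides
    max_epochs=10,
)
# trainer.fit(model, datamodule=dm)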
Dear @pedrocolon93,
Ah! I am! Now I am confused: if I don't use self.manual_backward, does that sync not happen? What would that cause?
This is rather complicated, but let me try to explain :)
Here are the internals of the DistributedDataParallel forward function:
https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L112
def forward(self, *inputs, **kwargs):
    with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
        self.reducer.save_thread_local_state()
        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.logger.set_runtime_stats_and_log()
            self.num_iterations += 1
            self.reducer.prepare_for_forward()

        # Notify the join context that this process has not joined, if
        # needed
        work = _Join.notify_join_context(self)
        if work:
            self.reducer._set_forward_pass_work_handle(
                work, self._divide_by_initial_world_size
            )

        # Calling _rebuild_buckets before forward compuation,
        # It may allocate new buckets before deallocating old buckets
        # inside _rebuild_buckets. To save peak memory usage,
        # call _rebuild_buckets before the peak memory usage increases
        # during forward computation.
        # This should be called only once during whole training period.
        if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
            logging.info("Reducer buckets have been rebuilt in this iteration.")

        if self.require_forward_param_sync:
            self._sync_params()

        if self._join_config.enable:
            # Notify joined ranks whether they should sync in backwards pass or not.
            self._check_global_requires_backward_grad_sync(is_joined_rank=False)

        if self.device_ids:
            inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
            output = self.module(*inputs[0], **kwargs[0])
        else:
            output = self.module(*inputs, **kwargs)

        if torch.is_grad_enabled() and self.require_backward_grad_sync:
            self.require_forward_param_sync = True
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters and not self.static_graph:
                # Do not need to populate this for static graph.
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            self.require_forward_param_sync = False

        # TODO. Right now we add this sink for static_graph training only. once
        # this feature is stable, we will add this sink for all cases. E.g.
        # This sink can help capture more accuracte backward start time as well.
        if self.static_graph and self.num_iterations == 1:
            # Need to grab list of tensors from user output in order to pass
            # to custom autograd function.
            output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                output
            )
            output_placeholders = [None for _ in range(len(output_tensor_list))]
            # Do not touch tensors that have no grad_fn, which can cause issues
            # such as https://github.com/pytorch/pytorch/issues/60733
            for i, output in enumerate(output_tensor_list):
                if torch.is_tensor(output) and output.grad_fn is None:
                    output_placeholders[i] = output

            passthrough_tensor_list = _DDPSink.apply(self.reducer, *output_tensor_list)
            for i in range(len(output_placeholders)):
                if output_placeholders[i] is None:
                    output_placeholders[i] = passthrough_tensor_list[i]

            # Reconstruct output data structure.
            output = _tree_unflatten_with_rref(
                output_placeholders, treespec, output_is_rref
            )
        return output
The logic that makes sure the Reducer knows about the current graph, so it can properly perform the reduction, is here:
if torch.is_grad_enabled() and self.require_backward_grad_sync:
    self.require_forward_param_sync = True
    # We'll return the output object verbatim since it is a freeform
    # object. We need to find any tensors in this object, though,
    # because we need to figure out which parameters were used during
    # this forward pass, to ensure we short circuit reduction for any
    # unused parameters. Only if `find_unused_parameters` is set.
    if self.find_unused_parameters and not self.static_graph:
        # Do not need to populate this for static graph.
        self.reducer.prepare_for_backward(list(_find_tensors(output)))
    else:
        self.reducer.prepare_for_backward([])
else:
    self.require_forward_param_sync = False
However, in manual optimization mode, you perform the backward pass and the optimizer step directly inside your model's forward function, i.e. within these lines, and at that point the reducer hasn't seen the output yet (traditionally the loss):
if self.device_ids:
    inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
    output = self.module(*inputs[0], **kwargs[0])
else:
    output = self.module(*inputs, **kwargs)
Therefore, if you call loss.backward and then optimizer.step directly, the gradients won't be reduced across processes and your model weights will start diverging.
However, self.manual_backward internally provides the loss to the reducer, so the reduction is done on the backward call. I hope this helps clarify why you need to use self.manual_backward and where the extra time is being spent.
I would recommend using manual optimization whenever your optimization has special needs that don't fit automatic optimization.
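For reference, here is a minimal manual-optimization sketch in the GAN spirit of this issue (not the code posted above: the generator, discriminator, and the (real, noise) batch layout are placeholders). The key point is that every loss goes through self.manual_backward so the DDP reducer sees it; depending on the Lightning version, the optimizer may also need to be passed, as in the original code.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class TinyGAN(pl.LightningModule):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.automatic_optimization = False  # switch to manual optimization
        self.generator = generator
        self.discriminator = discriminator

    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()
        real, noise = batch
        ones = torch.ones(real.size(0), 1, device=self.device)
        zeros = torch.zeros(real.size(0), 1, device=self.device)

        # Discriminator update: detach the fake sample so G gets no gradient here.
        d_opt.zero_grad()
        fake = self.generator(noise).detach()
        d_loss = F.binary_cross_entropy_with_logits(self.discriminator(real), ones) + \
                 F.binary_cross_entropy_with_logits(self.discriminator(fake), zeros)
        self.manual_backward(d_loss)  # lets the DDP reducer sync the D gradients
        d_opt.step()

        # Generator update: try to fool the (just updated) discriminator.
        g_opt.zero_grad()
        g_loss = F.binary_cross_entropy_with_logits(self.discriminator(self.generator(noise)), ones)
        self.manual_backward(g_loss)  # lets the DDP reducer sync the G gradients
        g_opt.step()

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return g_opt, d_opt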
Best, T.C
Dear @pedrocolon93,
I will be closing this issue. Feel free to re-open if you have more questions.
Best, T.C
Excellent!! Thanks for the clarification!
🐛 Bug
I am using a GAN-like system that uses manual_backward to run the backward pass for the discriminator and generator losses separately. I've noticed that it is considerably slower than the regular backward. Is there anything I am missing?
To Reproduce
I can upload a version of the code later, but I was wondering if anyone had this issue.
Expected behavior
A backward pass at the regular speed; in this case it is extremely slow.
Environment
Additional context