Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai

manual_backward Slower than backward #8629

Closed pedrocolon93 closed 3 years ago

pedrocolon93 commented 3 years ago

🐛 Bug

I am using a GAN-like system that uses manual_backward to run the backward pass for the discriminator and generator losses separately. I've noticed that it is considerably slower than the regular backward(). Is there anything I am missing?

To Reproduce

I can upload a version of the code later, but I was wondering if anyone had this issue.

Expected behavior

The backward pass should run at its regular speed, but in this case it is extremely slow.

Environment

Additional context

pedrocolon93 commented 3 years ago

The code in question is the following:

        g_opt, d_opt = self.optimizers()
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        real_label = torch.ones((len(train_batch["text"]), 1)).to(self.discriminator.device)
        fake_label = torch.zeros((len(train_batch["text"]), 1)).to(self.discriminator.device)
        only_text = self.generator_tok(train_batch["text"],
                                       max_length=self.max_source_length,
                                       padding="max_length",
                                       truncation=True,
                                       return_tensors="pt")
        if "mem" in train_batch:
            update_mem = []
            with self.generator_tok.as_target_tokenizer():
                rework = np.array(train_batch["mem"])
                try:
                    rework = rework.squeeze(1)
                except ValueError:
                    # axis 1 is not a singleton dimension; keep the shape as-is
                    pass
                rework = rework.transpose()
                for mem_set in rework:
                    dec_enc_relevant_rels = self.generator_tok(mem_set.tolist(), max_length=128,
                                                           padding="max_length", truncation=True)

                    update_mem.append(torch.tensor(dec_enc_relevant_rels["input_ids"],device=self.generator.device))
            update_mem = torch.stack(update_mem,0)
        else:
            update_mem = None

        if batch_idx>self.discriminator_freeze_iters:
            self.discriminator.zero_grad()
            # Format batch
            real_batch = []
            for i in range(len(train_batch["text"])):
                # Condition on the input_text and the actual generation
                real_batch.append(train_batch["text"][i]+self.discriminator.tokenizer.sep_token+train_batch["relation"][i])

            real_cpu = self.discriminator_tok(real_batch,
                                              max_length=self.max_target_length+self.max_source_length+1,
                                              padding="max_length",
                                              truncation=True,
                                              return_tensors="pt")
            for i in real_cpu:
                real_cpu[i] = real_cpu[i].to(self.discriminator.device)
            # Forward pass real batch through D
            output = self.discriminator(real_cpu)
            # Calculate loss on all-real batch
            errD_real = F.binary_cross_entropy_with_logits(output, real_label)
            # Calculate gradients for D in backward pass
            # self.manual_backward(errD_real, d_opt)
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors

            # Generate fake image batch with G

            for i in only_text:
                only_text[i] = only_text[i].to(self.generator.device)
            fake = self.generator.generate(
                input_ids=only_text['input_ids'],
                attention_mask=only_text["attention_mask"],
                max_length=self.max_target_length,
                num_beams=1,
                do_sample=False,
                update_mem=update_mem,
                # use_mem=True,
                # clear_mem=True,
            )
            fake_texts = []
            for i in range(fake.shape[0]):
                gids = fake[i, :].tolist()
                s = self.generator_tok.decode(gids, skip_special_tokens=False, clean_up_tokenization_spaces=True) # Replace end of sentence  stuff
                fake_texts.append(s)

            fake_batch = []
            for i in range(len(train_batch["text"])):
                # Condition on the input_text and the actual generation
                fake_batch.append(train_batch["text"][i]+self.discriminator.tokenizer.sep_token+fake_texts[i])
            fake_cpu = self.discriminator_tok(fake_batch,
                                              max_length=self.max_target_length + self.max_source_length + 1,
                                              padding="max_length",
                                              truncation=True,
                                              return_tensors="pt")
            for i in fake_cpu:
                fake_cpu[i] = fake_cpu[i].to(self.discriminator.device)
            # Classify all fake batch with D
            output = self.discriminator(fake_cpu) # Should be detached by here.
            # Calculate D's loss on the all-fake batch

            errD_fake = F.binary_cross_entropy_with_logits(output, fake_label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            # self.manual_backward(errD_fake, d_opt)

            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            items = []
            for i in range(len(train_batch["text"])):
                # Condition on the input_text and the actual generation
                items.append(train_batch["text"][i]+self.discriminator.tokenizer.sep_token+train_batch["relation"][i])

            amount_of_falses = int(len(items))
            # <subj> dust <obj> the refrigerator <\/relation>
            subjects = [item[item.rfind("<subj>") + 6:item.rfind("<obj>")] for item in items]
            objects = [item[item.rfind("<obj>") + 5:item.rfind("</relation>")] for item in items]
            random.shuffle(subjects)
            random.shuffle(objects)
            falses = []
            n, p = 1, .5  # n = coins flipped, p = prob of success

            for i in range(amount_of_falses):
                flip = np.random.binomial(n, p)
                samp = random.sample(items, 1)[0]
                if flip == 1:  # subj
                    subj = random.sample(subjects, 1)[0]
                    samp = samp[:samp.rfind("<subj>") + 6] + subj + samp[samp.rfind("<obj>"):]
                else:
                    obj = random.sample(objects, 1)[0]
                    samp = samp[:samp.rfind("<obj>") + 5] + obj + samp[samp.rfind("</relation>"):]

                falses.append(samp)
            all = items + falses
            conf_y = [[1] for i in range(len(items))] + [[0] for i in range(len(falses))]
            conf_y = torch.tensor(conf_y,device=self.discriminator.device)
            real_cpu = self.discriminator_tok(all,
                                              max_length=self.max_target_length+self.max_source_length+1,
                                              padding="max_length",
                                              truncation=True,
                                              return_tensors="pt")
            for i in real_cpu:
                real_cpu[i] = real_cpu[i].to(self.discriminator.device)
            # Forward pass real batch through D
            output = self.discriminator(real_cpu)
            # Calculate loss on all-real batch
            errD_conf = F.binary_cross_entropy_with_logits(output, conf_y.float())
            # Calculate gradients for D in backward pass
            # self.manual_backward(errD_conf, d_opt)
            errD_conf.backward()

            errD = errD_real + errD_fake + errD_conf
            # Update D
            d_opt.step()
            # self.manual_optimizer_step(d_opt, force_optimizer_step=True, custom_args=4)
        else:
            errD = torch.tensor([0])
            errD_real = torch.tensor([0])
            errD_fake = torch.tensor([0])
            errD_conf = torch.tensor([0])
            print("Skipping d update")

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        self.generator.zero_grad()
        # Format the targets for loss calc
        padding = "max_length"
        targets = train_batch["relation"]
        # Setup the tokenizer for targets
        with self.generator_tok.as_target_tokenizer():
            labels = self.generator_tok(targets, max_length=self.max_target_length,
                                        padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != self.generator_tok.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        only_text["labels"] = torch.tensor(labels["input_ids"])
        for i in only_text:
            only_text[i] = only_text[i].to(self.generator.device)

        # Normal LM loss for the gen

        lm_loss = self.generator(**only_text,update_mem=update_mem,
                use_mem=True,
                clear_mem=True)
        if batch_idx>self.discriminator_freeze_iters:
            # fake = self.generator.generate(
            #     input_ids=only_text['input_ids'],
            #     attention_mask=only_text["attention_mask"],
            #     max_length=self.max_target_length,
            #     do_sample=False
            # )
            logits_processor = LogitsProcessorList([
                MinLengthLogitsProcessor(15, eos_token_id=self.generator.config.eos_token_id),
                ])
            encoder_outputs = self.generator.model.encoder(only_text['input_ids'], return_dict=True, output_hidden_states=True)
            decoder_input_ids = torch.tensor([[self.generator.config.decoder_start_token_id] for x in range(len(train_batch["text"]))]).to(self.generator.device)
            fake = self.generator.greedy_search(decoder_input_ids,
                                                encoder_outputs=encoder_outputs,
                                                logits_processor=logits_processor,
                                                max_length=self.max_target_length,
                                                output_scores=True,
                                                return_dict_in_generate=True,
                                                update_mem=update_mem
                                                )
            vectors = []
            #Need to do on the cpu for memory.
            for s in fake["scores"]:
                number_of_gpus = 16
                sms = nn.functional.softmax(torch.clamp(self.scale, 0, 1000) * s, dim=-1)
                C_split = torch.split(self.discriminator.encoder.shared.weight, self.discriminator.encoder.shared.weight.shape[1] // number_of_gpus, dim=1)

                # multiply against the embedding matrix one chunk at a time to limit peak memory
                D_split = []
                for i in range(number_of_gpus):
                    # device = 'cuda:{:d}'.format(i)
                    D_split.append(sms @ C_split[i])

                # DO THIS ONLY IF YOU HAVE ENOUGH CPU MEMORY!! :
                embeddings = torch.cat([d.cpu() for d in D_split], dim=1)
                # embeddings = torch.matmul(sms.cpu(),self.discriminator.encoder.shared.weight.cpu()*self.discriminator.encoder.encoder.embed_scale)
                # embeddings = torch.bmm(sms,self.discriminator.encoder.shared.weight*self.discriminator.encoder.encoder.embed_scale)
                vectors.append(embeddings)
            fake_vectors = torch.stack(vectors,1).to(self.discriminator.device) # We have the 127 vectors that correspond to the generated output,
            self.discriminator.encoder.shared.to(self.discriminator.device)  # module .to() is in-place; tensor .to() on .weight would be a no-op
            # now we need to join that with the vectors that correspond to the conditioned text and the separator
            bos_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.bos_token_id]).to(self.discriminator.device))*self.discriminator.encoder.encoder.embed_scale
            eos_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.eos_token_id]).to(self.discriminator.device))*self.discriminator.encoder.encoder.embed_scale
            sep_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.sep_token_id]).to(self.discriminator.device))*self.discriminator.encoder.encoder.embed_scale
            pad_vec = self.discriminator.encoder.shared(torch.tensor([self.discriminator_tok.pad_token_id]).to(self.discriminator.device))*self.discriminator.encoder.encoder.embed_scale

            fake_texts = []
            for i in range(fake_vectors.shape[0]):
                # print(fake_vectors[i,:,:].shape,i,fake_vectors.shape[0])
                gids = fake["sequences"][i, :].tolist()
                to_remove = [j for j in range(0,len(gids)) if gids[j]==self.discriminator_tok.pad_token_id]+\
                [j for j in range(0,len(gids)) if gids[j] == self.discriminator_tok.bos_token_id] +\
                [j for j in range(0,len(gids)) if gids[j] == self.discriminator_tok.eos_token_id] +\
                [j for j in range(0,len(gids)) if gids[j] == self.discriminator_tok.sep_token_id]
                to_remove = sorted(list(set(to_remove)))
                to_keep = [j for j in range(0,len(gids)-1) if j not in to_remove]
                # print('TO KEEP',to_keep)
                # print('NOT TO KEEP',to_remove)
                # print(fake_vectors[i,to_keep,:].shape)
                fake_texts.append(fake_vectors[i,to_keep,:])

            fake_batch = []
            for i in range(len(train_batch["text"])):
                # Condition on the input_text and the actual generation
                fake_batch.append(self.discriminator.encoder.shared(
                    self.discriminator_tok(train_batch["text"][i],
                                                         add_special_tokens=False,
                                                         truncation=True,
                                                         max_length=self.max_source_length,
                                                         return_tensors="pt")["input_ids"].to(self.discriminator.device)
                                                                    )*self.discriminator.encoder.encoder.embed_scale)

            final_fake_batch = {
                "inputs_embeds":[],
                "attention_mask":[]
            }
            for i in range(len(train_batch["text"])):
                vecs = [bos_vec,fake_batch[i].squeeze(0),sep_vec,fake_texts[i],eos_vec]
                vecs = torch.cat(vecs,0)
                tot = vecs.shape[0]
                att = [1 for i in range(0,tot)]+[0 for i in range(0,self.max_target_length+self.max_source_length+1-tot)]
                vecs = torch.cat([vecs]+[pad_vec for i in range(0,self.max_target_length+self.max_source_length+1-tot)],0)
                att = torch.tensor(att).to(self.discriminator.device)
                final_fake_batch["inputs_embeds"].append(vecs)
                final_fake_batch["attention_mask"].append(att)  # att is already a tensor on the right device

            for i in final_fake_batch:
                final_fake_batch[i] =  torch.stack(final_fake_batch[i],0)
                final_fake_batch[i] = final_fake_batch[i].to(self.discriminator.device)

            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = self.discriminator(final_fake_batch)
            # Calculate G's loss based on this output
            critic_loss =  F.binary_cross_entropy_with_logits(output, real_label)
            errG = critic_loss + lm_loss[0]
        else:
            critic_loss = torch.tensor([0])
            errG = lm_loss[0]
        # Calculate gradients for G
        errG.backward()
        # self.manual_backward(errG, g_opt)

        # Update G
        g_opt.step()
pedrocolon93 commented 3 years ago

In particular, whenever I switch from backward(...) to self.manual_backward(error, opt), it is considerably slower (even for FP16 calculations).

tchaton commented 3 years ago

Are you using multiple GPUs? If yes, self.manual_backward takes care of gradient synchronisation, which adds an overhead.
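
For a GAN-style step like the one above, one way to cut that overhead is to combine the per-term discriminator losses into a single backward call. A rough sketch, reusing the names and the manual_backward(loss, opt) signature from your snippet (gradients add up linearly, so the result is the same; the trade-off is that all three graphs stay in memory until the combined backward):

    # If each backward triggers its own gradient all-reduce, three calls mean
    # three synchronisation rounds per discriminator update:
    #   self.manual_backward(errD_real, d_opt)
    #   self.manual_backward(errD_fake, d_opt)
    #   self.manual_backward(errD_conf, d_opt)

    # Summing the losses first keeps a single backward pass, so the bucketed
    # all-reduce runs only once:
    errD = errD_real + errD_fake + errD_conf
    self.manual_backward(errD, d_opt)
    d_opt.step()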

pedrocolon93 commented 3 years ago

Ah! I am! Now I am confused: if I don't use self.manual_backward, does that sync not happen? What would that cause?

pedrocolon93 commented 3 years ago

And as a follow-up, what would you recommend for training this? I have 1 machine with 4 GPUs.

pedrocolon93 commented 3 years ago

For anyone having this issue, I recommend reading through https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html and https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html. Following those guides sped things up a lot.
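
For 1 machine with 4 GPUs, those guides essentially point at DDP plus native mixed precision. A minimal sketch of that Trainer setup (argument names are for the 1.3/1.4-era API and may differ in other versions; model and dm are placeholders for your LightningModule and DataModule):

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        gpus=4,              # all 4 GPUs on the single machine
        accelerator="ddp",   # one process per GPU; much faster than "dp"
        precision=16,        # native mixed precision
        benchmark=True,      # let cuDNN tune kernels for fixed input shapes
    )
    trainer.fit(model, datamodule=dm)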

tchaton commented 3 years ago

Dear @pedrocolon93,

Ah! I am! Now I am confused: if I don't use self.manual_backward, does that sync not happen? What would that cause?

This is rather complicated, but let me try to explain :)

Here are the internals of DistributedDataParallel's forward function:

https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L112

    def forward(self, *inputs, **kwargs):
        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
            self.reducer.save_thread_local_state()
            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.logger.set_runtime_stats_and_log()
                self.num_iterations += 1
                self.reducer.prepare_for_forward()

            # Notify the join context that this process has not joined, if
            # needed
            work = _Join.notify_join_context(self)
            if work:
                self.reducer._set_forward_pass_work_handle(
                    work, self._divide_by_initial_world_size
                )

            # Calling _rebuild_buckets before forward compuation,
            # It may allocate new buckets before deallocating old buckets
            # inside _rebuild_buckets. To save peak memory usage,
            # call _rebuild_buckets before the peak memory usage increases
            # during forward computation.
            # This should be called only once during whole training period.
            if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                logging.info("Reducer buckets have been rebuilt in this iteration.")

            if self.require_forward_param_sync:
                self._sync_params()

            if self._join_config.enable:
                # Notify joined ranks whether they should sync in backwards pass or not.
                self._check_global_requires_backward_grad_sync(is_joined_rank=False)

            if self.device_ids:
                inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
                output = self.module(*inputs[0], **kwargs[0])
            else:
                output = self.module(*inputs, **kwargs)

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.require_forward_param_sync = True
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

        # TODO. Right now we add this sink for static_graph training only. once
        # this feature is stable, we will add this sink for all cases. E.g.
        # This sink can help capture more accuracte backward start time as well.
        if self.static_graph and self.num_iterations == 1:
            # Need to grab list of tensors from user output in order to pass
            # to custom autograd function.
            output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                output
            )
            output_placeholders = [None for _ in range(len(output_tensor_list))]
            # Do not touch tensors that have no grad_fn, which can cause issues
            # such as https://github.com/pytorch/pytorch/issues/60733
            for i, output in enumerate(output_tensor_list):
                if torch.is_tensor(output) and output.grad_fn is None:
                    output_placeholders[i] = output

            passthrough_tensor_list = _DDPSink.apply(self.reducer, *output_tensor_list)
            for i in range(len(output_placeholders)):
                if output_placeholders[i] is None:
                    output_placeholders[i] = passthrough_tensor_list[i]

            # Reconstruct output data structure.
            output = _tree_unflatten_with_rref(
                output_placeholders, treespec, output_is_rref
            )
        return output

The logic that makes sure the Reducer knows about the current graph, so it can properly perform the reduction, is here:

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                self.require_forward_param_sync = True
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

However, in manual optimization mode you are performing the backward and optimizer calls directly inside the forward function of your model, i.e. within the lines below, at which point the Reducer hasn't seen the output (traditionally the loss) yet.

            if self.device_ids:
                inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
                output = self.module(*inputs[0], **kwargs[0])
            else:
                output = self.module(*inputs, **kwargs)

Therefore, if you call loss.backward() and then optimizer.step() directly, the gradients won't be reduced across processes and your model weights will start diverging.

However, self.manual_backward internally provides the loss to the Reducer, so the reduction is performed during the backward call.
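
Roughly, that extra step can be sketched like this (not the exact Lightning source; _find_tensors is the private helper that also appears in the DDP code quoted above):

    import torch
    from torch.nn.parallel.distributed import _find_tensors

    def manual_backward_sketch(ddp_model, loss):
        # Hand the Reducer the tensors that close the current graph so that the
        # bucketed gradient all-reduce fires while autograd walks backward.
        if torch.is_grad_enabled() and ddp_model.require_backward_grad_sync:
            ddp_model.reducer.prepare_for_backward(list(_find_tensors(loss)))
        loss.backward()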

I hope this helps clarify why you need to use self.manual_backward and where the extra time is being spent.

I would recommend using manual optimization whenever your optimization has special needs that don't fit automatic optimization.
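
For reference, the general shape of a manual-optimization GAN step looks like this (a minimal sketch against the public LightningModule API; compute_d_loss / compute_g_loss are placeholders for your own losses, and recent versions accept self.manual_backward(loss) without the optimizer argument):

    import torch
    import pytorch_lightning as pl

    class GAN(pl.LightningModule):
        def __init__(self, generator, discriminator):
            super().__init__()
            self.generator = generator
            self.discriminator = discriminator
            self.automatic_optimization = False  # opt into manual optimization

        def training_step(self, batch, batch_idx):
            g_opt, d_opt = self.optimizers()

            # Discriminator update
            d_opt.zero_grad()
            d_loss = self.compute_d_loss(batch)  # placeholder
            self.manual_backward(d_loss)         # handles DDP gradient sync
            d_opt.step()

            # Generator update
            g_opt.zero_grad()
            g_loss = self.compute_g_loss(batch)  # placeholder
            self.manual_backward(g_loss)
            g_opt.step()

        def configure_optimizers(self):
            g_opt = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            return g_opt, d_opt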

Best, T.C

tchaton commented 3 years ago

Dear @pedrocolon93,

I will be closing this issue. Feel free to re-open if you have more questions.

Best, T.C

pedrocolon93 commented 3 years ago

Excellent!! Thanks for the clarification!