uber-research / PPLM

Plug and Play Language Model implementation. Allows steering the topic and attributes of GPT-2 models.

Question about how to handle `inputs_embeds` #28

Open forest1988 opened 4 years ago

forest1988 commented 4 years ago

Hi,

Thank you for sharing your great work!

I have a question about how to handle `inputs_embeds` in the PPLM code.

When I looked at run_pplm.py, I found something whose intention I could not understand: https://github.com/uber-research/PPLM/blob/master/run_pplm.py#L220

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)

`inputs_embeds` is updated in the for loop, but `curr_probs` and `wte.weight.data` do not seem to be updated within it.

Could you please tell me why `inputs_embeds` is recalculated inside the for loop?

Thank you in advance!

dathath commented 4 years ago

I think that is actually a bug, and it might also explain why our experiments with `horizon_length > 1` did not work so well (we use `horizon_length=1` in all of our experiments). If you're running with `horizon_length=1` it shouldn't matter, but it is a bug: we do need to update `curr_probs` inside the loop here.
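
For reference, a rough sketch of the kind of change we'd need, assuming the first return value of `model(...)` is the LM logits (as elsewhere in run_pplm.py) and ignoring temperature; this is an untested sketch, not a verified fix:

    for _ in range(horizon_length):
        inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
        curr_logits, curr_unpert_past, curr_all_hidden = model(
            past=curr_unpert_past,
            inputs_embeds=inputs_embeds
        )
        # refresh curr_probs from the new logits so the next iteration
        # feeds the model's updated next-token distribution forward
        curr_probs = torch.nn.functional.softmax(
            curr_logits[:, -1, :], dim=-1
        )
        curr_probs = torch.unsqueeze(curr_probs, dim=1)  # (batch, 1, vocab)
        curr_hidden = curr_all_hidden[-1]
        new_accumulated_hidden = new_accumulated_hidden + torch.sum(
            curr_hidden, dim=1)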

forest1988 commented 3 years ago

I'm sorry my reply is so delayed; I missed the notification that you had kindly responded to this issue.

Thank you for the detailed information. I now understand that what I mentioned is a bug, and that you used `horizon_length=1` in your experiments so that it does not matter.

I also understand that `curr_probs` needs to be updated inside the loop. I'm not a good programmer, but if I can do anything helpful, please tell me.

Again, I apologize for my late reply.