junyangwang0410 / Knight

SotA text-only image/video method (IJCAI 2023)
10 stars 0 forks source link

请教一下generate()方法中prefix参数。 #4

Open kk-dark opened 2 months ago

kk-dark commented 2 months ago
def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device):
    text = "prefix prefix prefix prefix prefix:"
    inputs = tokenizer(text, return_tensors="pt")
    output = model.generate(inputs["input_ids"].to(device), 40, prefix = image_feature, do_sample = False, num_beams=5)[0]
    output = tokenizer.decode(output)
    return output.split(':')[1].split('.')[0].lower()

如上这段代码model.generate()方法中用到了一个prefix参数,我在查阅Huggingface的文档中并没有找到关于prefix参数的解释。

在modeling_gpt2.py文件中,我找到了如下部分代码:

def forward(
        ...
        prefix: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        ...

以及:

...
if inputs_embeds is None:
    inputs_embeds = self.wte(input_ids)
if prefix != None:
    prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2])
    inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
...

这段部分的添加应该是作者的修改对吗?期待您的回复。

junyangwang0410 commented 2 months ago

是的。正如你所说,我们修改了transformers的代码来完成目的。

kk-dark commented 2 months ago

是的。正如你所说,我们修改了transformers的代码来完成目的。

感谢您的回复!