Open kk-dark opened 2 months ago
def caption_generation(image_feature, model: GPT2LMHeadModel, tokenizer, device): text = "prefix prefix prefix prefix prefix:" inputs = tokenizer(text, return_tensors="pt") output = model.generate(inputs["input_ids"].to(device), 40, prefix = image_feature, do_sample = False, num_beams=5)[0] output = tokenizer.decode(output) return output.split(':')[1].split('.')[0].lower()
如上这段代码model.generate()方法中用到了一个prefix参数,我在查阅Huggingface的文档中并没有找到关于prefix参数的解释。
在modeling_gpt2.py文件中,我找到了如下部分代码:
def forward( ... prefix: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: ...
以及:
... if inputs_embeds is None: inputs_embeds = self.wte(input_ids) if prefix != None: prefix = prefix.expand(inputs_embeds.shape[0], 5, inputs_embeds.shape[2]) inputs_embeds = torch.cat((prefix, inputs_embeds[:, 5:, :]), dim = 1) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds ...
这段部分的添加应该是作者的修改对吗?期待您的回复。
是的。正如你所说,我们修改了transformers的代码来完成目的。
感谢您的回复!
如上这段代码model.generate()方法中用到了一个prefix参数,我在查阅Huggingface的文档中并没有找到关于prefix参数的解释。
在modeling_gpt2.py文件中,我找到了如下部分代码:
以及:
这段部分的添加应该是作者的修改对吗?期待您的回复。