LinWeizheDragon / FLMR

The Hugging Face implementation of the Fine-grained Late-interaction Multi-modal Retriever.

How to use the PreFLMR model in downstream tasks? #12

Closed · fengkangjie closed this 3 months ago

fengkangjie commented 4 months ago

As the title says: given a custom dataset of images and documents, how should I finetune the PreFLMR model on it so that I can retrieve relevant documents from an image? Do I need to train image-document embedding alignment? I would greatly appreciate an answer to these questions.

LinWeizheDragon commented 4 months ago

You can finetune the PreFLMR model (using the provided forward function and the returned ib_loss) with your own data. The procedure can be the same as what we did for the WIT dataset. If you are sure that only Image->Text is required (you don't need Image+Text->Text), you can mask out the token embeddings from the text encoder entirely by passing query_concat_output_from_text_encoder=False. This is similar to the setting in our first pre-training phase. Remember to freeze the visual encoder at all times.
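For reference, a minimal finetuning sketch along these lines (not the official script): the model/tokenizer loading follows the pattern in this repo's README, the forward() arguments mirror those discussed later in this thread, `finetune_step` is a hypothetical helper, and the `"vision"` substring used to freeze the visual encoder is an assumption, so check `model.named_parameters()` on your checkpoint.

```python
import torch
from transformers import AutoImageProcessor
from flmr import FLMRModelForRetrieval, FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer

checkpoint = "LinWeizheDragon/PreFLMR_ViT-L"  # or PreFLMR_ViT-B
query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint, subfolder="query_tokenizer")
context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint, subfolder="context_tokenizer")
model = FLMRModelForRetrieval.from_pretrained(
    checkpoint, query_tokenizer=query_tokenizer, context_tokenizer=context_tokenizer
)
# The image processor should match the checkpoint's vision backbone.
image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Freeze the visual encoder at all times. The "vision" substring is an assumption;
# verify the actual module names of your checkpoint with model.named_parameters().
for name, param in model.named_parameters():
    if "vision" in name:
        param.requires_grad = False

optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-5)

def finetune_step(instructions, images, passages, num_negative_examples=4):
    """One training step. `passages` holds, for each query, 1 positive followed
    by `num_negative_examples` negatives (positives first)."""
    Q = query_tokenizer(instructions)
    D = context_tokenizer(passages)
    pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
    out = model(
        query_input_ids=Q["input_ids"],
        query_attention_mask=Q["attention_mask"],
        query_pixel_values=pixel_values,
        context_input_ids=D["input_ids"],
        context_attention_mask=D["attention_mask"],
        use_in_batch_negatives=True,
        num_negative_examples=num_negative_examples,
        # Image->Text only: drop the query-side text-encoder token embeddings
        query_concat_output_from_text_encoder=False,
    )
    loss = out["ib_loss"]  # in-batch negative loss returned by forward()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
```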

fengkangjie commented 4 months ago

Generally speaking, roughly how much data is recommended if I want finetuning of PreFLMR to reach a reasonably good result?

LinWeizheDragon commented 4 months ago

At least >10k examples for finetuning if the data distribution is not too far from the pretraining tasks (M2KR). If you shift to a completely new distribution (such as Chinese documents), you probably need >100k examples to attain good generalisability.

fengkangjie commented 4 months ago

One more thing I would like to ask. When testing the steps in the "Training with contrastive learning" section, the questions passed to query_tokenizer are "Using.....: What is the capital of France?" and "Extract ....: What is the capital of China?", and the contexts are ["Paris is the capital of France.", "Beijing is the capital of China.", "Paris is the capital of France.", "Beijing is the capital of China."]. Logically, for these two questions the correct answers should be the 1st and the 2nd items of the two context groups respectively, but when I checked the code (flmr/models/flmr/modeling_flmr.py, line 844), the labels are all 0.

In other words, the first answer is always treated as the correct one, so I would like to ask whether the context order in the demo given on that page is incorrect?

LinWeizheDragon commented 4 months ago

Hi, these two are different. The labels in the forward function are for training. In training, you always pass in batches like:

- query_input: query1, query2, query3, ...
- context_input: pos_item_for_q1, neg_item1_for_q1, neg_item2_for_q1, ..., pos_item_for_q2, neg_item1_for_q2, neg_item2_for_q2, ...

in which case the first item is always the one you want the model to prioritise. Therefore the labels are always 0

The code is just for demonstration; it shows you the scores of different inputs. Since the first and second items are the correct answers for query1 and query2 respectively, you will see the (already pre-trained) model give them higher scores (lower loss).
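To make the layout concrete, here is a small hypothetical training batch in the "positive first" convention described above (instruction prefixes abbreviated as in the comment above); it only illustrates the ordering, not a full training loop.

```python
import torch

# Each query contributes 1 positive followed by num_negative_examples negatives,
# so the ground-truth index inside every group is 0 and the labels are all zeros.
queries = [
    "Using.....: What is the capital of France?",
    "Extract ....: What is the capital of China?",
]
num_negative_examples = 1
contexts = [
    "Paris is the capital of France.",   # positive for query 1
    "Beijing is the capital of China.",  # negative for query 1
    "Beijing is the capital of China.",  # positive for query 2
    "Paris is the capital of France.",   # negative for query 2
]
labels = torch.zeros(len(queries), dtype=torch.long)  # all-zero labels, as observed in modeling_flmr.py
```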

fengkangjie commented 4 months ago

I prepared 25k domain-specific images and 25k corresponding Chinese documents, and finetuned the PreFLMR model with contrastive learning on the pytorch-lightning framework, starting from the PreFLMR_ViT-B pretrained checkpoint (learning rate 1e-5, batch size 10, 5 epochs). The test-set losses and scores I got (see the attached screenshot) show that the scores of the positive examples are not separated from those of the negatives, and none exceed 40, so I would like to know whether something is wrong with my training setup. I have only just started working on this and still have a lot to learn, so please forgive the disturbance! The training code is as follows:

```python
import os
from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from PIL import Image
from pytorch_lightning import Trainer

# initModel() and FLMRDataModule are defined elsewhere in my project.


class FlmrPLModel(pl.LightningModule):

    def __init__(self, learning_rate=1e-3, save_path='.'):
        super().__init__()
        self.save_hyperparameters()
        self.save_path = save_path
        self.img_proc, self.query_tokenizer, self.context_tokenizer, self.flmr = initModel()

    def forward(self, x):
        x = self.flmr.forward(x)
        return x

    # define the loss
    def training_step(self, batch, batch_idx):
        query, img_paths, contexts = batch
        Q_encoding = self.query_tokenizer(query)
        # load and preprocess the query images
        imgs = []
        for img_path in img_paths:
            img = Image.open(img_path).convert("RGB")
            imgs.append(img)
        Q_pixel_values = self.img_proc(imgs, return_tensors="pt")['pixel_values']
        # flatten the per-query context lists (positive + negatives)
        ctxs = []
        for ctx in contexts:
            ctxs.extend(ctx)
        D_encoding = self.context_tokenizer(ctxs)
        inputs = dict(
            query_input_ids=Q_encoding['input_ids'],
            query_attention_mask=Q_encoding['attention_mask'],
            query_pixel_values=Q_pixel_values,
            context_input_ids=D_encoding['input_ids'],
            context_attention_mask=D_encoding['attention_mask'],
            use_in_batch_negatives=True,
            query_concat_output_from_text_encoder=False,
            num_negative_examples=4
        )

        res = self.flmr.forward(**inputs)
        loss = res['loss']
        scores = res['scores']
        print('loss=', loss, ';scores=', scores)
        return {"loss": loss, "scores": scores.detach()}

    def training_step_end(self, outputs):
        return {"loss": outputs["loss"].mean()}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, eps=1e-08)

    def test_step(self, batch, batch_idx):
        query, img_paths, contexts = batch
        Q_encoding = self.query_tokenizer(query)
        imgs = []
        for img_path in img_paths:
            img = Image.open(img_path).convert("RGB")
            imgs.append(img)
        Q_pixel_values = self.img_proc(imgs, return_tensors="pt")['pixel_values']
        ctxs = []
        for ctx in contexts:
            ctxs.extend(ctx)
        D_encoding = self.context_tokenizer(ctxs)
        inputs = dict(
            query_input_ids=Q_encoding['input_ids'],
            query_attention_mask=Q_encoding['attention_mask'],
            query_pixel_values=Q_pixel_values,
            context_input_ids=D_encoding['input_ids'],
            context_attention_mask=D_encoding['attention_mask'],
            use_in_batch_negatives=True,
            num_negative_examples=4
        )
        res = self.flmr.forward(**inputs)
        loss = res['loss']
        scores = res['scores']
        print("test loss=", loss, ";scores=", scores)
        return {"loss": loss, "scores": scores.detach()}

    def test_step_end(self, outputs):
        self.log("test_loss", outputs["loss"].mean(), on_epoch=True, on_step=False)

    def on_test_epoch_end(self):
        print('Saving model in the Huggingface format...')
        path_save_model = os.path.join(self.save_path, 'step_{}'.format(self.global_step))
        self.flmr.save_pretrained(path_save_model)
        print('Model has been saved to {}'.format(path_save_model))

    @staticmethod
    def add_model_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=1e-5)
        parser.add_argument('--save_path', type=str, default='.')
        parser.add_argument('--gpus', type=int, default=1)
        parser.add_argument('--max_epochs', type=int, default=10)
        return parser

def main(hparams):
    pl.seed_everything(0)
    data_mnist = FLMRDataModule(
        data_dir=hparams.data_dir,
        batch_size=hparams.batch_size,
        num_workers=hparams.num_workers,
    )
    model = FlmrPLModel(learning_rate=hparams.learning_rate, save_path=hparams.save_path)
    trainer = Trainer(max_epochs=hparams.max_epochs, callbacks=[])

    trainer.fit(model, data_mnist)
    model.eval()
    result = trainer.test(model, data_mnist, ckpt_path='best')
    print(result)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = FLMRDataModule.add_dataset_args(parser)
    parser = FlmrPLModel.add_model_args(parser)
    hparams = parser.parse_args()
    main(hparams)
```

LinWeizheDragon commented 4 months ago

Hi, could you please share the data in one batch, including the query (which should be an instruction in your case) and the corresponding pos/neg documents? That would help me understand your case better.

Since you are training on a Chinese corpus, you will need more epochs and steps. Below are the curves of PreFLMR-L on another corpus with Chinese queries and contexts, for your reference. I suggest using the L model instead of B for better performance.

[training curves of PreFLMR-L on a Chinese query/context corpus omitted]

LinWeizheDragon commented 4 months ago

batch size=8, grad_accum=8, negatives=4, trained for 200k steps. You should freeze the ViT encoder if you haven't done so! See the paper for the results when the ViT encoder is unfrozen.
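As a sketch, these settings map onto a PyTorch Lightning Trainer roughly as follows (the ViT-freezing loop from the earlier sketch still applies; the batch size itself is set in your DataLoader/DataModule).

```python
from pytorch_lightning import Trainer

# Settings quoted above: gradient accumulation 8, 200k optimisation steps.
trainer = Trainer(
    max_steps=200_000,          # trained for 200k steps
    accumulate_grad_batches=8,  # grad_accum=8
    max_epochs=-1,              # let max_steps decide when to stop
)
```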

LinWeizheDragon commented 4 months ago

The test loss is still high in your case. Normally you get good performance when the test loss goes below 0.1-0.2.

LinWeizheDragon commented 4 months ago

Re your code, you should use

```python
res = self.flmr.forward(**inputs)
loss = res['ib_loss']
```

to use the in-batch negative loss instead of the direct contrastive loss.

And you need to keep the inputs to train_step and test_step consistent. I noticed that you set query_concat_output_from_text_encoder=False in the training step but not in the test step, which may be why you end up with a high test loss.
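A sketch of one way to keep the two steps consistent in the FlmrPLModel above: build the forward inputs in a single shared helper and optimise/report ib_loss in both places. The helper name `_build_inputs` is hypothetical, not part of the FLMR API.

```python
    def _build_inputs(self, batch):
        query, img_paths, contexts = batch
        Q_encoding = self.query_tokenizer(query)
        imgs = [Image.open(p).convert("RGB") for p in img_paths]
        Q_pixel_values = self.img_proc(imgs, return_tensors="pt")["pixel_values"]
        ctxs = [c for group in contexts for c in group]
        D_encoding = self.context_tokenizer(ctxs)
        return dict(
            query_input_ids=Q_encoding["input_ids"],
            query_attention_mask=Q_encoding["attention_mask"],
            query_pixel_values=Q_pixel_values,
            context_input_ids=D_encoding["input_ids"],
            context_attention_mask=D_encoding["attention_mask"],
            use_in_batch_negatives=True,
            query_concat_output_from_text_encoder=False,  # same flag in train AND test
            num_negative_examples=4,
        )

    def training_step(self, batch, batch_idx):
        res = self.flmr.forward(**self._build_inputs(batch))
        return {"loss": res["ib_loss"], "scores": res["scores"].detach()}

    def test_step(self, batch, batch_idx):
        res = self.flmr.forward(**self._build_inputs(batch))
        return {"loss": res["ib_loss"], "scores": res["scores"].detach()}
```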

fengkangjie commented 4 months ago

May I also ask: when finetuning the PreFLMR_ViT-L model, are there any restrictions on the pixel size of each training image and the length of each document?

LinWeizheDragon commented 4 months ago

The visual encoder of PreFLMR is a ViT. The input image is processed by the ViT's image processor, so a resolution above 256 is preferred.

The context encoder is a BERT model, and thus the length should be less than 512 tokens. In practice, you can split your documents so that each of them is less than 500 tokens.
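A minimal sketch of such splitting, using a plain Hugging Face BERT tokenizer just for counting tokens; the `split_document` helper and its 500-token budget with a 50-token overlap are illustrative choices, not part of FLMR.

```python
from transformers import AutoTokenizer

# Use a tokenizer matching your corpus, e.g. "bert-base-chinese" for Chinese documents.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def split_document(text, max_tokens=500, stride=50):
    """Split a long document into overlapping chunks of at most max_tokens tokens."""
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    chunks = []
    for start in range(0, len(token_ids), max_tokens - stride):
        piece = token_ids[start:start + max_tokens]
        chunks.append(tokenizer.decode(piece))
        if start + max_tokens >= len(token_ids):
            break
    return chunks
```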

LinWeizheDragon commented 4 months ago

Hi, just to let you know that a finetuning script is now available at https://github.com/LinWeizheDragon/FLMR?tab=readme-ov-file#new-finetune-the-preflmr-model-on-downstream-datasets

fengkangjie commented 4 months ago

> Hi, just to let you know that a finetuning script is now available at https://github.com/LinWeizheDragon/FLMR?tab=readme-ov-file#new-finetune-the-preflmr-model-on-downstream-datasets

Thank you very, very much.