LinWeizheDragon / FLMR

The huggingface implementation of Fine-grained Late-interaction Multi-modal Retriever.

question about details of finetuning script #16

Closed Maxlinn closed 3 weeks ago

Maxlinn commented 1 month ago

hi lin, i managed to write a finetuning script, could you help me check it? i also got confused about some details, listed below (also marked with NOTE in the code comments); could you clarify them? thanks!

  1. in preflmr finetuning (paper B.2.), infoseek was finetuned on 4 gpus with batch size 8 and gradient accumulation step 8, so the effective batch size per step is 4 * 8 * 8 = 256. since infoseek was finetuned for 1k steps, that adds up to 256 * 1k = 256k examples (the arithmetic is written out below this list). however, the m2kr train datasheet lists 100k examples for infoseek (in the hf repo it is actually 600k). are the 256k examples made up of multiple epochs over the 100k, or sampled from the 600k?
  2. in training, an example in the train dataset has multiple positive passages (stored in pos_item_contents); is one sampled at random from pos_item_contents in dataset.__getitem__?
  3. in training, an example needs 4 negative passages; are those sampled at random from the non-positive passages in the knowledge base?
  4. in collate_fn, should in_batch_negatives_from_all_gpus be True or False (by default it is False)?
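for reference, the effective-batch arithmetic behind question 1 (just the paper's B.2 numbers, nothing taken from the codebase):

# effective-batch arithmetic from question 1 (paper B.2 numbers; purely illustrative)
n_gpus = 4
per_device_train_batch_size = 8
gradient_accumulation_steps = 8
max_steps = 1000

examples_per_step = n_gpus * per_device_train_batch_size * gradient_accumulation_steps  # 256
total_examples_seen = examples_per_step * max_steps                                     # 256,000
epochs_over_100k_train_set = total_examples_seen / 100_000                              # ~2.56

and here is the full script: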
import transformers
from transformers import TrainingArguments, Trainer, HfArgumentParser
from transformers import AutoImageProcessor
from datasets import load_dataset

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset

import random
import os
from PIL import Image
from pprint import pformat
from dataclasses import dataclass

from flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval

@dataclass
class MyArguments:
    model_name_or_path :str = "LinWeizheDragon/PreFLMR_ViT-G"
    image_processor_name :str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    dataset_name :str = "BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR"
    dataset_subset_name :str = "Infoseek"
    query_images_dir :str = "./Infoseek/train_images"
    num_negative_examples :int = 4

@dataclass
class PreFLMRTrainingArguments(TrainingArguments):
    # set according to PreFLMR paper 
    remove_unused_columns :bool = False
    per_device_train_batch_size :int = 8
    gradient_accumulation_steps :int = 8
    logging_steps :int = 1
    eval_strategy :str = 'no'
    save_strategy :str = 'steps'
    save_steps :int = 500
    max_steps :int = 1000
    save_only_model :bool = True
    save_total_limit :int = 5
    seed :int = 42
    # manually set optimizer later
    mapping_structure_lr :float = 1e-4
    non_mapping_structure_lr :float = 1e-5

class PreFLMRTrainer(Trainer):

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs, return_dict=True)

        # replace with ib loss
        ib_loss = outputs["in_batch_negative_loss"]
        outputs["deprecated_loss"] = outputs["loss"]
        outputs["loss"] = ib_loss

        return (ib_loss, outputs) if return_outputs else ib_loss

class PreFLMRDataset(Dataset):

    def __init__(self,
                 args,
                 data_df, passages_df, 
                 query_tokenizer, context_tokenizer, image_processor):
        self.args = args
        self.data_df = data_df
        self.passages_df = passages_df
        self.query_tokenizer = query_tokenizer
        self.context_tokenizer = context_tokenizer
        self.image_processor = image_processor

        self.unique_passage_ids = set(self.passages_df.index)

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        row = self.data_df.iloc[idx]
        # NOTE: the *infoseek* subset happens to have the instruction prepended to `question`.
        #   for other subsets, the instruction must be added explicitly.
        # query = instruction + row['question']
        assert ':' in row['question'], 'Only Infoseek has the instruction prepended to the question in the `question` field.'
        query = row['question']

        pos_item_ids = row['pos_item_ids']
        pos_item_id = random.choice(pos_item_ids)  # NOTE: random choice here?
        pos_passage = self.passages_df.loc[pos_item_id]['passage_content']

        neg_item_ids = random.sample(list(self.unique_passage_ids - set(pos_item_ids)), 
                                     self.args.num_negative_examples)  # NOTE: random sampling here?
        neg_passages = [self.passages_df.loc[neg_item_id]['passage_content'] for neg_item_id in neg_item_ids]

        image_path = os.path.join(self.args.query_images_dir, row['img_path'])
        image = Image.open(image_path).convert('RGB')
        pixel_values = self.image_processor(image, return_tensors='pt')['pixel_values'] # [1, 3, 224, 224]

        return dict(
            query=query,
            pos_passage=pos_passage,
            neg_passages=neg_passages,
            pixel_values=pixel_values
        )

    def collate_fn(self, batch):
        queries = [ex['query'] for ex in batch]
        passages = [] # [pos, neg, neg, neg, pos, ...]
        for ex in batch:
            passages.append(ex['pos_passage'])
            passages.extend(ex['neg_passages'])

        Q_encoding = self.query_tokenizer(queries)
        Q_pixel_values = torch.cat([ex['pixel_values'] for ex in batch], dim=0)
        D_encoding = self.context_tokenizer(passages)

        # according to `modeling_flmr.py, FLMRModelForRetrieval.forward`
        inputs = dict(
            query_input_ids=Q_encoding['input_ids'],
            query_attention_mask=Q_encoding['attention_mask'],
            query_pixel_values=Q_pixel_values,
            context_input_ids=D_encoding['input_ids'],
            context_attention_mask=D_encoding['attention_mask'],
            use_in_batch_negatives=True,
            in_batch_negatives_from_all_gpus=False, # NOTE should be False here?
            num_negative_examples=self.args.num_negative_examples
        )
        return inputs

def main():
    parser = HfArgumentParser((MyArguments, PreFLMRTrainingArguments))
    my_args, training_args = parser.parse_args_into_dataclasses()

    ## setting up
    assert dist.get_world_size() == 4, 'The paper uses 4 gpus.'
    if dist.get_rank() == 0:
        print('## my_args: ', pformat(my_args))
        print('## training_args: ', pformat(training_args))

    ## setting up tokenizer
    query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(
        my_args.model_name_or_path, subfolder="query_tokenizer")
    context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(
        my_args.model_name_or_path, subfolder="context_tokenizer")
    image_processor = AutoImageProcessor.from_pretrained(my_args.image_processor_name)

    ## setting up dataset
    data_df = load_dataset(my_args.dataset_name, f'{my_args.dataset_subset_name}_data')['train']\
        .to_pandas().set_index('question_id')
    passages_df = load_dataset(my_args.dataset_name, f'{my_args.dataset_subset_name}_passages')['train_passages']\
        .to_pandas().set_index('passage_id')

    dataset = PreFLMRDataset(args=my_args,
                             data_df=data_df, passages_df=passages_df,
                             query_tokenizer=query_tokenizer, 
                             context_tokenizer=context_tokenizer, 
                             image_processor=image_processor)

    if dist.get_rank() == 0:
        print(pformat(f'## dataset[0]: {dataset[0]}'))

    ## setting up model
    model = FLMRModelForRetrieval.from_pretrained(
        my_args.model_name_or_path,
        query_tokenizer=query_tokenizer,
        context_tokenizer=context_tokenizer)

    ## setting up training
    # build trainables    
    # PreFLMR consists of
    #   pretrained_structure_modules: pretrained text and vision encoder, remain frozen
    #   mapping_structure_modules: a 2-layer MLP F_M^MLP and a Transformer block F_M^TR, lr = 1e-4
    #   non_mapping_structure_modules: remaining modules, mostly linears. lr = 1e-5
    pretrained_structure_modules = [
        model.query_text_encoder, # FLMRTextModel
        model.query_vision_encoder, # FLMRVisionModel
        model.context_text_encoder, # FLMRTextModel
        model.context_vision_encoder # FLMRVisionModel
    ]
    mapping_structure_modules = [
        model.query_vision_projection, # FLMRMultiLayerPerceptron
        model.context_vision_projection, # FLMRMultiLayerPerceptron
        model.transformer_mapping_network, # BertEncoder
    ]
    non_mapping_structure_modules = [
        model.query_text_encoder_linear, # Linear
        model.context_text_encoder_linear, # Linear
        model.transformer_mapping_input_linear, # Linear
        model.transformer_mapping_output_linear, # Linear
    ]
    # check that all parameters are included, nothing left out
    assert set(id(p) for p in model.parameters()) == set(id(p) \
        for module in pretrained_structure_modules + mapping_structure_modules + non_mapping_structure_modules
        for p in module.parameters())

    for module in pretrained_structure_modules:
        for p in module.parameters():
            p.requires_grad = False

    if dist.get_rank() == 0:
        trainables = [pn for pn, p in model.named_parameters() if p.requires_grad]
        n_trainables = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(pformat(f'## trainables: {trainables}'))
        print(pformat(f'## n_trainables: {n_trainables:,}'))

    # build optimizer and constant scheduler
    #   need to deduplicate modules (some modules may share parameters)
    optimizer_groups = []

    mapping_structure_modules_dedup = list({id(m) : m for m in mapping_structure_modules}.values())
    for module in mapping_structure_modules_dedup:
        optimizer_groups.append(dict(params=module.parameters(), lr=training_args.mapping_structure_lr))

    non_mapping_structure_modules_dedup = list({id(m) : m for m in (non_mapping_structure_modules)}.values())
    for module in non_mapping_structure_modules_dedup:
        optimizer_groups.append(dict(params=module.parameters(), lr=training_args.non_mapping_structure_lr))

    optimizer = torch.optim.Adam(optimizer_groups)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)

    ## start training
    trainer = PreFLMRTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=dataset.collate_fn,
        optimizers=[optimizer, scheduler]
    )
    trainer.train()

    trainer.save_state()
    trainer.save_model()
    query_tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'query_tokenizer'))
    context_tokenizer.save_pretrained(os.path.join(training_args.output_dir, 'context_tokenizer'))

if __name__ == '__main__':
    main()

thanks in advance for your generous help!

LinWeizheDragon commented 1 month ago

Hi,

Our team has been less busy recently, and I just got some time to make some updates to PreFLMR:

  1. Fixed some issues in the M2KR benchmark and added instructions to each entry (this should not affect your case too much)
  2. Fixed the evaluation script to make sure all results reported in the paper can be reproduced
  3. Added a finetuning script which can finetune PreFLMR ViT-G to achieve 74% Pseudo Recall@5 on E-VQA

We plan to release these changes in a few days. At that time you can cross-check your script with ours.

LinWeizheDragon commented 1 month ago

Re your question:

  1. Yes, you can run more than one epoch - but note that Infoseek's data distribution differs between the train and validation sets, so training may lead to significant overfitting. In our work, we just trained for a fixed number of steps. You can lower the learning rate (e.g. to 5e-6) and add warmup steps to obtain a steadier training curve (a minimal sketch of this follows below). Specifically for Infoseek, you may want a small save interval to capture the best model.

  2. Yes, randomly sampled

  3. Yes, randomly sampled from the whole training passage corpus (ofc excluding those in pos_item_ids)

  4. We did not observe a performance improvement from setting it to True, but I will leave that to you - you can try it out if you have multiple GPUs.
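For point 1, a minimal sketch of a lower learning rate with linear warmup, using transformers.get_scheduler (also used in the finetuning script quoted later in this thread); the values 5e-6, 100 warmup steps and 1000 training steps are just the ones mentioned in this thread, and `model` stands for the loaded FLMRModelForRetrieval:

import torch
from transformers import get_scheduler

# illustrative values from this thread: lr 5e-6, 100 warmup steps, 1000 training steps
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)
scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)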

Maxlinn commented 1 month ago

heartfelt thanks for your timely reply!

LinWeizheDragon commented 1 month ago

Hi, just to let you know that a finetuning script is now available at https://github.com/LinWeizheDragon/FLMR?tab=readme-ov-file#new-finetune-the-preflmr-model-on-downstream-datasets

Maxlinn commented 1 month ago

thanks for letting me know! i've tried the released finetuning script, but i find that after finetuning on infoseek there seems to be a slight degradation.

the finetuning script is mostly the same as example_finetune_preflmr.py, with a few changes (marked with the :add: tag in the comments):

  1. limit the train dataset to 100k examples (randomly selected), since the paper uses 100k.
  2. disable validation, since there is no val set for infoseek, by setting the val dataset to trivial examples.
  3. set max_steps to 1000 and trigger model saving, since the paper finetunes on infoseek for 1k steps.
  4. use 4 gpus, since the paper uses 4 A100s.

i noticed that 1k steps over 100k examples is more than 2 epochs (1000 * 4 * 8 * 8 / 100k ≈ 2.56 epochs in the end).

Epoch 2:  56%|█████▌    | 1744/3125 [08:36<06:49,  3.38it/s, v_num=6ter, train/loss_step=0.00174, train/lr[0]=5.56e-9, train/loss_epoch=0.0244]

could you shed some light on how i can reproduce the numbers in the paper? many thanks!


zero-shot test results on Infoseek

Total number of questions: 4708
Pseudo Recall@1:     0.30713678844519965
Pseudo Recall@5:     0.5632965165675446
Pseudo Recall@10:    0.6703483432455395
Pseudo Recall@20:    0.7627442650807137
Pseudo Recall@50:    0.8540781648258283
Pseudo Recall@100:   0.8920985556499575
Recall@1:    0.18011894647408666
Recall@5:    0.3734069668649108
Recall@10:   0.4607051826677995
Recall@20:   0.5482158028887001
Recall@50:   0.6675870858113849
Recall@100:  0.7327952421410365

after finetuning for 1000 steps:

Total number of questions: 4708
Pseudo Recall@1:     0.2941801189464741
Pseudo Recall@5:     0.5686066270178419
Pseudo Recall@10:    0.6799065420560748
Pseudo Recall@20:    0.7735768903993203
Pseudo Recall@50:    0.8666100254885302
Pseudo Recall@100:   0.9022939677145284
Recall@1:    0.1745964316057774
Recall@5:    0.37616822429906543
Recall@10:   0.47344944774851316
Recall@20:   0.5694562446898895
Recall@50:   0.6964740866610025
Recall@100:  0.7657179269328802

python preflmr_finetune.py \
    --image_root_dir "Infoseek/train_images" \
    --dataset_hf_path "multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR" \
    --dataset "Infoseek" \
    --freeze_vit \
    --log_with_wandb \
    --model_save_path "${RUN_DIR}" \
    --checkpoint_path "PreFLMR_ViT-G" \
    --image_processor_name "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" \
    --batch_size 8 \
    --accumulate_grad_batches 8 \
    --valid_batch_size 16 \
    --test_batch_size 64 \
    --mode train \
    --max_steps 1000 \
    --learning_rate 0.000005 \
    --warmup_steps 100 \
    --accelerator auto \
    --devices auto \
    --strategy ddp_find_unused_parameters_true \
    --num_sanity_val_steps 2 \
    --precision bf16 \
    --val_check_interval 2000 \
    --save_top_k -1 \
    --devices "4,5,6,7" \
    --wandb_project "${RUN_NAME}" \
    --make_val_trivial \
    --sample_train_examples 100000
import os
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from datasets import load_dataset, DatasetDict
from transformers import set_seed, AutoImageProcessor
from PIL import Image
import argparse
import random
from easydict import EasyDict
import numpy as np
import shutil
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import WandbLogger
from functools import partial

from flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval

class RetrievalModel(LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters()

        self.checkpoint_path = self.args.checkpoint_path
        self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(self.checkpoint_path, subfolder="query_tokenizer")
        self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(self.checkpoint_path, subfolder="context_tokenizer")
        self.image_processor = AutoImageProcessor.from_pretrained(self.args.image_processor_name)

        # Load and prepare datasets
        self.prepare_datasets()

        self.train_dataloader()

        self.model = FLMRModelForRetrieval.from_pretrained(self.checkpoint_path,
                                                           query_tokenizer=self.query_tokenizer,
                                                           context_tokenizer=self.context_tokenizer)

        if self.args.freeze_vit:
            # freeze parameters of query_encoder and context_encoder
            for name, param in self.model.query_vision_encoder.named_parameters():
                param.requires_grad = False
            for name, param in self.model.context_vision_encoder.named_parameters():
                param.requires_grad = False

    def prepare_datasets(self):
        self.dataset = load_dataset(self.args.dataset_hf_path, self.args.dataset + "_data")
        self.passage_ds = load_dataset(self.args.dataset_hf_path, self.args.dataset + "_passages")

        # :add:
        if self.args.sample_train_examples is not None:
            print(f"## sample_train_examples ({self.args.sample_train_examples})...")
            self.dataset['train'] = self.dataset['train'].shuffle(seed=42).select(range(self.args.sample_train_examples))

        def add_path_prefix_in_img_path(example, prefix):
            example["img_path"] = os.path.join(prefix, example["img_path"])
            return example

        self.dataset = self.dataset.map(add_path_prefix_in_img_path, fn_kwargs={"prefix": self.args.image_root_dir})

        instructions = [
            "Using the provided image, obtain documents that address the subsequent question: ",
            "Retrieve documents that provide an answer to the question alongside the image: ",
            "Extract documents linked to the question provided in conjunction with the image: ",
            "Utilizing the given image, obtain documents that respond to the following question: ",
            "Using the given image, access documents that provide insights into the following question: ",
            "Obtain documents that correspond to the inquiry alongside the provided image: ",
            "With the provided image, gather documents that offer a solution to the question: ",
            "Utilizing the given image, obtain documents that respond to the following question: ",
        ]

        def prepare_inputs(sample):
            sample = EasyDict(sample)

            random_instruction = random.choice(instructions)
            input_text_sequence = " ".join(
                [random_instruction]
                + [sample.question]
            )

            sample["input_text_sequence"] = input_text_sequence

            return sample

        self.dataset = self.dataset.map(prepare_inputs)

        print(self.dataset['train'][0])

        # Tokenize and prepare image pixels for input
        # ds = ds.map(
        #     tokenize_inputs,
        #     fn_kwargs={"query_tokenizer": self.query_tokenizer, "context_tokenizer": self.context_tokenizer, "image_processor": self.image_processor},
        #     batched=True,
        #     batch_size=8,
        #     num_proc=16,
        # )

    def collate_fn(self, batch, passage_split="train_passages"):

        batch_data = {}

        input_text_sequences = [example['input_text_sequence'] for example in batch]
        encoding = self.query_tokenizer(input_text_sequences)
        query_input_ids = encoding["input_ids"]
        query_attention_mask = encoding["attention_mask"]

        img_paths = [example['img_path'] for example in batch]
        pixel_values = []
        for img_path in img_paths:
            image = Image.open(img_path).convert("RGB")
            encoded = self.image_processor(image, return_tensors="pt")
            pixel_values.append(encoded.pixel_values)

        pixel_values = torch.stack(pixel_values, dim=0)

        num_negative_examples = self.args.num_negative_examples

        def negative_sampling(pos_item_ids, num_samples=1):
            """Generate negative samples for a query. ONLY used in training
            Args:
                user_item (int tensor): user id
                num_samples (int, optional): number of samples. Defaults to 1.
            Returns:
                neg_items: list of negative item ids.
            """
            neg_items = []

            while len(neg_items) < num_samples:
                # sample num_samples negative items for the user
                while True:
                    # index into the passage split itself, not the DatasetDict of splits
                    neg_item = np.random.randint(low=0, high=len(self.passage_ds[passage_split]), size=1)[0]

                    VALID = True
                    neg_item = self.passage_ds[passage_split][int(neg_item)]
                    if neg_item['passage_id'] in pos_item_ids:
                        VALID = False

                    if VALID == True:
                        break
                neg_items.append(neg_item)
            return neg_items

        batched_context_input_sequences = []

        for example in batch:
            select_pos_item_index = random.sample(range(len(example['pos_item_ids'])), k=1)[0]
            pos_item_id = example['pos_item_ids'][select_pos_item_index]
            pos_item_content = example['pos_item_contents'][select_pos_item_index]

            batched_context_input_sequences.append(pos_item_content)

            # exclude every annotated positive when sampling negatives
            neg_items = negative_sampling(example['pos_item_ids'], num_samples=num_negative_examples)
            neg_item_ids = [item['passage_id'] for item in neg_items]
            neg_item_contents = [item['passage_content'] for item in neg_items]

            batched_context_input_sequences.extend(neg_item_contents)

        context_encoding = self.context_tokenizer(batched_context_input_sequences)
        context_input_ids = context_encoding["input_ids"]
        context_attention_mask = context_encoding["attention_mask"]

        batch_data.update(dict(
            query_input_ids=query_input_ids,
            query_attention_mask=query_attention_mask,
            query_pixel_values=pixel_values,
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
        ))
        # print(query_input_ids.shape)
        # print(query_attention_mask.shape)
        # print(pixel_values.shape)
        # print(context_input_ids.shape)
        # print(context_attention_mask.shape)
        return batch_data

    def train_dataloader(self):
        # Create a partial function with parameters
        parametrized_collate_fn = partial(self.collate_fn, passage_split="train_passages")

        dataloader = DataLoader(
            self.dataset['train'], 
            batch_size=self.args.batch_size, 
            shuffle=True, 
            collate_fn=parametrized_collate_fn,
            num_workers=4,
        )
        return dataloader

    def val_dataloader(self):
        # :add:
        if self.args.make_val_trivial:
            print("## make_val_trivial working...")
            # Create a partial function with parameters
            parametrized_collate_fn = partial(self.collate_fn, passage_split="train_passages")
            dataloader = DataLoader(
                self.dataset['train'].select(range(1)), 
                batch_size=self.args.valid_batch_size, 
                collate_fn=parametrized_collate_fn,
            )
            return dataloader

        # Create a partial function with parameters
        parametrized_collate_fn = partial(self.collate_fn, passage_split="valid_passages")

        dataloader = DataLoader(
            self.dataset['valid'], 
            batch_size=self.args.valid_batch_size, 
            collate_fn=parametrized_collate_fn,
            num_workers=2,
        )
        return dataloader

    def test_dataloader(self):
        # Create a partial function with parameters
        parametrized_collate_fn = partial(self.collate_fn, passage_split="test_passages")

        dataloader = DataLoader(
            self.dataset['test'], 
            batch_size=self.args.test_batch_size, 
            collate_fn=parametrized_collate_fn,
            num_workers=2,
        )
        return dataloader

    def forward(self, batch):
        batch = {
            k: v.to(self.device) for k,v in batch.items()
        }
        # Prepare inputs for model
        inputs = {
            'query_input_ids': batch['query_input_ids'],
            'query_attention_mask': batch['query_attention_mask'],
            'query_pixel_values': batch['query_pixel_values'],
            'context_input_ids': batch['context_input_ids'],
            'context_attention_mask': batch['context_attention_mask'],
            'use_in_batch_negatives': True,
            "num_negative_examples": self.args.num_negative_examples,
        }

        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self.forward(batch)
        loss = outputs.loss
        self.log('train/loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True)

        current_lrs = self.scheduler.get_last_lr()
        for index, current_lr in enumerate(current_lrs):
            self.log(f"train/lr[{index}]", current_lr, prog_bar=True, on_step=True, logger=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(batch)
        loss = outputs.loss
        self.log('valid/loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        outputs = self.forward(batch)
        loss = outputs.loss
        self.log('test/loss', loss, on_step=True, on_epoch=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.learning_rate)

        from transformers import get_scheduler
        self.scheduler = get_scheduler(
            "linear",
            optimizer=self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )

        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

class ModelSaveCallback(Callback):
    def __init__(self, save_path, save_top_k=3):
        self.save_path = save_path
        self.best_models = []
        self.save_top_k = save_top_k

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.state.stage in ['sanity_check']:
            return
        current_loss = trainer.callback_metrics["valid/loss"].item()
        current_step = trainer.global_step
        model_name = f"model_step_{current_step}"
        model_path = os.path.join(self.save_path, model_name)

        if self.save_top_k == -1:
            # save all models
            pl_module.model.save_pretrained(model_path)
            pl_module.query_tokenizer.save_pretrained(os.path.join(model_path, "query_tokenizer"))
            pl_module.context_tokenizer.save_pretrained(os.path.join(model_path, "context_tokenizer"))
            print(f"\nThe metric is {current_loss}, save_top_k=-1. Saving {model_path}")
            return

        if len(self.best_models) < self.save_top_k or current_loss < max(self.best_models, key=lambda x: x[1])[1]:
            print(f"\nThe metric is {current_loss}, at the top {self.save_top_k}. Saving {model_path}")

            self.best_models.append((model_path, current_loss))
            self.best_models.sort(key=lambda x: x[1])

            if len(self.best_models) > self.save_top_k:
                removed_model = self.best_models.pop()
                print("\nRemoving", removed_model[0])
                try:
                    shutil.rmtree(removed_model[0], ignore_errors=True)
                except Exception as e:
                    print(f"\nRemove failed. The file may have been removed. \nError: {e}")

            pl_module.model.save_pretrained(model_path)
            pl_module.query_tokenizer.save_pretrained(os.path.join(model_path, "query_tokenizer"))
            pl_module.context_tokenizer.save_pretrained(os.path.join(model_path, "context_tokenizer"))

        else:
            print(f"\nThe current metric is {current_loss}, not at the top {self.save_top_k}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--valid_batch_size", type=int, default=64)
    parser.add_argument("--test_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=0)
    parser.add_argument("--save_top_k", type=int, default=-1)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--dataset_hf_path", type=str, required=True)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--mode", type=str, required=True)
    parser.add_argument("--log_with_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="finetune_preflmr")
    parser.add_argument("--image_root_dir", type=str, required=True)
    parser.add_argument("--image_processor_name", type=str, default="openai/clip-vit-large-patch14")
    parser.add_argument("--checkpoint_path", type=str, default="LinWeizheDragon/PreFLMR_ViT-L", required=True)
    parser.add_argument("--num_negative_examples", type=int, default=4)
    parser.add_argument("--freeze_vit", action="store_true")
    parser.add_argument("--model_save_path", type=str, default="saved_models")
    # :add:
    parser.add_argument("--make_val_trivial", action='store_true')
    parser.add_argument("--sample_train_examples", type=int, default=None)

    # Parse known and unknown arguments
    args, unknown_args = parser.parse_known_args()
    # Convert unknown args to kwargs for Trainer
    trainer_kwargs = {}
    it = iter(unknown_args)
    for key in it:
        if key.startswith('--'):
            key = key.lstrip('--')
            try:
                value = next(it)
                if value.isdigit():
                    value = int(value)
                elif value.replace('.', '', 1).isdigit():
                    value = float(value)
            except StopIteration:
                raise ValueError(f"Argument {key} lacks a corresponding value.")
            trainer_kwargs[key] = value

    set_seed(42)  # Set seeds for reproducibility

    model = RetrievalModel(args)
    print("trainer_kwargs", trainer_kwargs)

    # checkpoint_callback = ModelCheckpoint(monitor="valid/loss", mode="min", save_top_k=0, save_last=True)
    save_pretrained_callback = ModelSaveCallback(save_path=args.model_save_path, save_top_k=args.save_top_k)

    if args.log_with_wandb:
        wandb_logger = WandbLogger(project=args.wandb_project)

    trainer = Trainer(
        default_root_dir=args.model_save_path,
        callbacks=[save_pretrained_callback],
        enable_checkpointing=False,
        logger=wandb_logger if args.log_with_wandb else None,
        **trainer_kwargs)

    if args.mode == 'train':
        trainer.fit(model)

        # :add:
        print(f'## final saving at global_step: {trainer.global_step}')
        trainer.validate(model, dataloaders=model.val_dataloader())
    else:
        trainer.test(model)

if __name__ == "__main__":
    main()
LinWeizheDragon commented 1 month ago

Hi,

With our old codebase, after fine-tuning we performed a thorough performance check over the whole fine-tuning process to understand the behaviour of Infoseek fine-tuning:

(figure: retrieval performance over the course of fine-tuning)

It is clear that after 1k steps the model degrades quickly due to the data imbalance between the train and test splits (I personally don't understand why zero-shot evaluation is needed for a retriever, but this is their dataset and I can't change it...)

The learning rate in my fine-tuning was 1e-5, while the provided script uses 5e-6 to make the training more stable. I think that by changing the hyperparameters (e.g. extending the training time or increasing the learning rate) you will still be able to obtain an improvement. Due to the nature of this dataset, it will be difficult to find a sweet spot for all cases; you will have to find an empirical value on your own.

Maxlinn commented 1 month ago

thanks for sharing! may i know how many examples are consumed over the 1k steps? according to the paper's "finetuning on downstream tasks" part, for infoseek, 1k steps * 4 (n_gpus) * 8 (batch_size) * 8 (gradient_accumulation) would consume 256k examples, but there are only 100k examples in the m2kr infoseek train set in the paper. i am a little confused.

LinWeizheDragon commented 1 month ago

That means some examples were seen more than once. The dataloader randomly draws examples from the whole 100k pool.

Maxlinn commented 3 weeks ago

thanks for the clarification! sorry for the late reply :)