simon-ging / coot-videotext

COOT: Cooperative Hierarchical Transformer for Video-Text Representation Learning
Apache License 2.0

Fine tune yc2_100m_coot_vidclip_mart for smaller videos #34

Closed: DesaleF closed this issue 3 years ago

DesaleF commented 3 years ago

I was fine-tuning the provided model on my own dataset. My dataset is really small, so I want to use the pretrained weights. However, I got a size mismatch in the word embedding. Do you know how I can solve this error? I ran: python train_caption.py -c config/caption/paper2020/mydata_coot_vidclip_mart.yaml

[Screenshot: mart]

simon-ging commented 3 years ago

As line 5 of your screenshot shows, you seem to have created a vocabulary with 173 words. You are now loading a model whose embedding was built for a 992-word vocabulary, so the weight shapes do not match.
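To see exactly which checkpoint entries clash with a model built for a new vocabulary, one can compare shapes key by key before loading. This is a minimal sketch in NumPy; the helper name `find_shape_mismatches` and the embedding width of 300 are illustrative assumptions, not part of the repository. In practice the two dicts would be the checkpoint's state dict and `model.state_dict()` (torch tensors expose `.shape` the same way).

```python
import numpy as np


def find_shape_mismatches(checkpoint_state, model_state):
    """Return (key, checkpoint_shape, model_shape) for every shared key whose shapes differ."""
    mismatches = []
    for key, ckpt_value in checkpoint_state.items():
        if key not in model_state:
            continue  # missing/unexpected keys are a separate problem
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[key].shape)
        if ckpt_shape != model_shape:
            mismatches.append((key, ckpt_shape, model_shape))
    return mismatches


# Toy example: checkpoint built for a 992-word vocabulary, new model for 173 words.
# The embedding width 300 is an arbitrary placeholder.
checkpoint = {"embeddings.word_embeddings.weight": np.zeros((992, 300))}
new_model = {"embeddings.word_embeddings.weight": np.zeros((173, 300))}
print(find_shape_mismatches(checkpoint, new_model))
# [('embeddings.word_embeddings.weight', (992, 300), (173, 300))]
```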

You could try not loading the embedding and instead initializing it from scratch or from GloVe weights. Alternatively, if all 173 words appear in the original 992-word vocabulary, you could map the embeddings from 992 rows to 173.
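The mapping option amounts to selecting, for each word in the new vocabulary, its row from the old embedding matrix. This is a minimal NumPy sketch, assuming both vocabularies are available as `word -> index` dicts with new indices running from 0 to `len(new_word2idx) - 1`; the function name and the random fallback for unseen words (small normal noise) are illustrative assumptions, not the repository's code.

```python
import numpy as np


def remap_embeddings(old_weight, old_word2idx, new_word2idx, seed=0):
    """Build an embedding matrix for a new vocabulary from an old one.

    Rows for words found in the old vocabulary are copied over; rows for
    unseen words keep a small random initialization (assumed fallback).
    Returns the new matrix and the list of words that were not found.
    """
    rng = np.random.default_rng(seed)
    vocab_size, dim = len(new_word2idx), old_weight.shape[1]
    new_weight = rng.normal(0.0, 0.02, size=(vocab_size, dim)).astype(old_weight.dtype)
    missing = []
    for word, new_idx in new_word2idx.items():
        old_idx = old_word2idx.get(word)
        if old_idx is None:
            missing.append(word)
        else:
            new_weight[new_idx] = old_weight[old_idx]
    return new_weight, missing


# Toy vocabularies (hypothetical words, 2-dimensional embeddings).
old_word2idx = {"a": 0, "man": 1, "car": 2}
old_weight = np.arange(6, dtype=np.float32).reshape(3, 2)
new_word2idx = {"a": 0, "man": 1, "suspect": 2}
new_weight, missing = remap_embeddings(old_weight, old_word2idx, new_word2idx)
print(missing)  # ['suspect']
```

If `missing` comes back empty, the mapping is exact; otherwise those rows have to be learned from scratch or taken from GloVe.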

DesaleF commented 3 years ago

@gingsi Thank you for your quick response. I think the better option now is to initialize from GloVe weights, since not all of the words appear in the original vocabulary. So how do I initialize from GloVe? Can you tell me what to change?

simon-ging commented 3 years ago

To build the new vocabulary and GloVe embeddings, see the file mart_build_vocab.py.
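For readers who want to see what building a GloVe matrix for a vocabulary involves, here is a minimal sketch. It is not the contents of mart_build_vocab.py; it only assumes the standard GloVe text format (one `word v1 v2 ... v_dim` entry per line) and a `word -> index` dict. The function name and the random initialization for words without a GloVe vector are illustrative assumptions.

```python
import io

import numpy as np


def build_glove_matrix(word2idx, glove_lines, dim, seed=0):
    """Return a (len(word2idx), dim) float32 matrix filled from GloVe text lines.

    Words that have no GloVe vector keep a small random initialization.
    Also returns the set of words that were found in the GloVe lines.
    """
    rng = np.random.default_rng(seed)
    matrix = rng.normal(0.0, 0.02, size=(len(word2idx), dim)).astype(np.float32)
    found = set()
    for line in glove_lines:
        # rsplit keeps words that contain spaces intact (some GloVe files have them)
        parts = line.rstrip().rsplit(" ", dim)
        word, values = parts[0], parts[1:]
        if word in word2idx and len(values) == dim:
            matrix[word2idx[word]] = np.asarray([float(v) for v in values], dtype=np.float32)
            found.add(word)
    return matrix, found


# Toy GloVe file with 3-dimensional vectors; a real run would iterate over an
# opened file such as glove.6B.300d.txt instead.
glove_text = "man 0.1 0.2 0.3\ncar 0.4 0.5 0.6\n"
word2idx = {"man": 0, "suspect": 1}
matrix, found = build_glove_matrix(word2idx, io.StringIO(glove_text), dim=3)
print(sorted(found))  # ['man']
```

The result is a plain NumPy array, which is the kind of object `th.from_numpy` accepts.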

Then you need to intercept model loading in the function hook_post_init in trainer_base.py. There, delete the embedding from the state_dict so that loading does not crash.
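Deleting the vocabulary-sized entries boils down to filtering keys out of the state dict before it is loaded. A minimal sketch follows; the helper name `drop_keys` is hypothetical, and the key set is copied from the `ignore_weights` list posted later in this thread, so it is worth checking which entries actually depend on the vocabulary size in your checkpoint.

```python
def drop_keys(state_dict, keys_to_drop):
    """Return a copy of state_dict without keys_to_drop, plus the keys that were removed."""
    kept = {k: v for k, v in state_dict.items() if k not in keys_to_drop}
    removed = [k for k in state_dict if k in keys_to_drop]
    return kept, removed


# Key names taken from the ignore_weights list later in this thread.
VOCAB_DEPENDENT_KEYS = {
    "embeddings.word_embeddings.weight",
    "decoder.decoder.weight",
    "decoder.bias",
    "loss_func.one_hot",
}

# Toy state dict; "encoder.weight" is a made-up key standing in for everything else.
state = {"encoder.weight": 1, "decoder.bias": 2, "embeddings.word_embeddings.weight": 3}
kept, removed = drop_keys(state, VOCAB_DEPENDENT_KEYS)
print(removed)  # ['decoder.bias', 'embeddings.word_embeddings.weight']
```

With plain PyTorch, a filtered dict like this can be loaded with `model.load_state_dict(kept, strict=False)`, so the dropped parameters keep their fresh initialization; whether the repository's own loading path passes `strict=False` is not confirmed here.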

Instead, set the GloVe embeddings as is done in the function create_mart_model in mart/model.py; see the line with model.embeddings.set_pretrained_embedding.

Also note that once you change the vocabulary, the pretrained model will no longer produce meaningful output, but it could still be worth a shot to fine-tune it afterwards.

DesaleF commented 3 years ago

@gingsi Thank you very much for your help. It works well with the GloVe embeddings. For anyone who wants to do the same, this is the change in trainer_base.py:

# assumes trainer_base.py already has: import torch as th / from pathlib import Path
if self.load_model:
    # load model from file. this would start training from epoch 0, but is usually only used for validation.
    self.logger.info(f"Loading model from checkpoint file {self.load_model}")
    model_state = th.load(str(self.load_model))
    # vocabulary-dependent weights (listed for reference; not used further in this snippet)
    ignore_weights = ["decoder.decoder.weight", "decoder.bias", "loss_func.one_hot", "embeddings.word_embeddings.weight"]

    # replace the pretrained 992-word embedding with the GloVe matrix built for the new vocabulary
    model_state['model']["embeddings.word_embeddings.weight"] = th.from_numpy(th.load(
        Path("cache_caption") / "dataset_vocab_glove.pt")).float()
    self.model_mgr.set_model_state(model_state)
DesaleF commented 3 years ago

Well, I realized that with the above modification there is still an error during retrieval, so it is better to add a condition like this:

# assumes trainer_base.py already has: import torch as th / from pathlib import Path
if self.load_model:
    # load model from file. this would start training from epoch 0, but is usually only used for validation.
    self.logger.info(f"Loading model from checkpoint file {self.load_model}")
    model_state = th.load(str(self.load_model))
    # only swap in the GloVe embedding for MART captioning configs; other config types (e.g. retrieval) load the checkpoint unchanged
    if self.cfg.config_type == "mart":
        model_state['model']["embeddings.word_embeddings.weight"] = th.from_numpy(th.load(
            Path("cache_caption") / "crime_vocab_glove.pt")).float()
    self.model_mgr.set_model_state(model_state)