Open Chandrayee opened 3 years ago
I meant to say loss_model.zero_grad(), loss_model.named_parameters() and loss_model.train()
Hi, yes, fit trains the model with all parameters (BERT and the 3-way classifier; pooling has no parameters).
What you can try is to train SBERT on the data and see if it improves the accuracy. Then you have at least a reference for your implementation.
Thanks for the suggestion. I am trying out zero-shot with "distiluse-base-multilingual-cased" and "xlm-r-bert-base-nli-stsb-mean-tokens". It looks like "xlm-r-bert-base-nli-stsb-mean-tokens" has better discriminative ability between entailment, contradiction and neutral (judging by the cosine similarity). I am a bit confused whether "xlm-r-bert-base-nli-stsb-mean-tokens" is just zero-shot XLM-RoBERTa, XLM-RoBERTa fine-tuned on NLI/STSb data, or a model trained with knowledge distillation. Also, what are the teacher and student models for "distiluse-base-multilingual-cased"? I read through the documentation and paper, but it was a bit unclear.
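For reference, this is roughly what the zero-shot comparison above looks like. The cosine helper is plain Python; the commented lines sketch the sentence-transformers calls (model name as above) and are an assumption about the setup, not code from this repository:

```python
def cosine_similarity(u, v):
    """Cosine similarity between two embedding vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = sum(a * a for a in u) ** 0.5
    norm_v = sum(b * b for b in v) ** 0.5
    return dot / (norm_u * norm_v)

# With sentence-transformers installed, the zero-shot comparison would be:
# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer("xlm-r-bert-base-nli-stsb-mean-tokens")
# u, v = model.encode(["A man is eating.", "A man is eating food."])
# print(cosine_similarity(u.tolist(), v.tolist()))
```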
Very well-commented repository. I am trying to implement part of your code for the Kaggle XNLI data, but I could not understand some parts of the implementation in SentenceTransformer.py. I can see that you are calling loss_model.backward(). Is the "fit" function used to train the model? Does the training include all the parameters of BERT, the pooling layer and the 3-way classifier?
Secondly, I am using all the hyperparameters suggested in your paper: AdamW with lr=1e-5 and eps=1e-8, a linear scheduler with a warm-up period, and the mean of the token embeddings of each sentence pair plus their absolute difference as features. I am basically copying your implementation, but my data set is much smaller and multilingual: 11,520 training examples. I use 100 warm-up steps, 2 epochs and a batch size of 16, and I am training all the parameters (BERT, pooling, classifier), but I am not seeing any improvement in accuracy. Should I change the hyperparameters?
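To make sure I am building the classifier input correctly: for mean-pooled sentence embeddings u and v, the features are the concatenation (u, v, |u − v|). A plain-Python sketch of that construction, with embeddings as lists of floats (the function name softmax_features is my own, not from the repository):

```python
def softmax_features(u, v):
    """Concatenate (u, v, |u - v|), the feature vector fed to the 3-way
    softmax classifier on top of the two sentence embeddings."""
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    return u + v + [abs(a - b) for a, b in zip(u, v)]
```

For embedding dimension d, this yields a 3*d-dimensional classifier input.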
I could just fine-tune SBERT, but I wanted to see if I can get some reasonable performance this way.