joeynmt / joeynmt

Minimalist NMT for educational purposes
Apache License 2.0

Implementing Knowledge distillation #227

Open vmenan opened 2 months ago

vmenan commented 2 months ago

Hi! I came across this library very recently and I am loving it! In my current research I am trying to implement knowledge distillation, which requires passing in multiple datasets. Here, a single training step is complete once one batch from each dataset has gone through the model. I am struggling a bit to extend the current JoeyNMT to achieve this, and I would really appreciate some help.

may- commented 2 months ago

Hello @vmenan,

I think collate_fn() is an appropriate place to extend. What input format do you use: plain text, TSV, or a Hugging Face dataset? If you could provide some dummy samples of your multi-source input data, I might be able to help you further :)
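As a rough illustration of the kind of extension meant here, below is a minimal, framework-free sketch of a collate function for multi-source data. It assumes (hypothetically) that each sample is a dict with already-tokenized `src`/`trg` id lists plus a `domain` tag naming its source dataset, and that the padding id is 1; it does not reproduce JoeyNMT's actual `collate_fn()` signature. In PyTorch, a function like this would be passed to `torch.utils.data.DataLoader` through its `collate_fn` argument.

```python
from typing import Dict, List

PAD_ID = 1  # hypothetical padding index; use your vocabulary's real pad id


def multi_source_collate(samples: List[Dict]) -> Dict:
    """Hypothetical collate_fn: pads src/trg token ids and keeps each sample's dataset tag."""

    def pad(seqs: List[List[int]]) -> List[List[int]]:
        # right-pad every sequence to the longest one in this batch
        max_len = max(len(s) for s in seqs)
        return [s + [PAD_ID] * (max_len - len(s)) for s in seqs]

    return {
        "src": pad([s["src"] for s in samples]),
        "trg": pad([s["trg"] for s in samples]),
        "src_length": [len(s["src"]) for s in samples],
        # keep the dataset tag so the loss can later be split per dataset if needed
        "domain": [s["domain"] for s in samples],
    }
```

Keeping the `domain` tag inside the batch is what would let the training code group or weight examples by their source dataset later on.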

vmenan commented 2 months ago

Hi @may-, thank you so much for your reply, and sorry for the delay. I will look into collate_fn(). Let me describe the task in more detail. I created a Hugging Face dataset class for an English-to-German translation dataset. For example, say I have 3 different English-German datasets (subtitles, parliament, and medical data), named A, B, and C, with 100K data points each. My plan was to override the training manager class so that the batch loop iterates over zip(A, B, C). That way I get 3 batches at once, pass them through the model, compute the loss for each batch, take a weighted sum of the losses, and finally take an optimizer step. Passing one batch from each dataset through the model counts as one training step.
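A minimal sketch of the loop described above, under stated assumptions: `loaders` are any iterables of batches (e.g. one DataLoader per dataset), `compute_loss` is a hypothetical stand-in for the model's forward pass plus loss computation, and `weights` are the interpolation weights. Plain floats are used so the sketch runs without PyTorch; a comment marks where the backward pass and optimizer step would go in a real training manager.

```python
from typing import Callable, Iterable, List, Sequence


def weighted_multi_dataset_steps(
    loaders: Sequence[Iterable],
    weights: Sequence[float],
    compute_loss: Callable[[object], float],  # hypothetical: forward pass + loss for one batch
) -> List[float]:
    """Hypothetical loop body: one training step = one batch from each loader."""
    if len(loaders) != len(weights):
        raise ValueError("need exactly one weight per loader")
    step_losses = []
    # zip() stops at the shortest loader, so every step sees exactly one batch per dataset
    for batches in zip(*loaders):
        losses = [compute_loss(batch) for batch in batches]
        # weighted sum of the per-dataset losses
        total = sum(w * loss for w, loss in zip(weights, losses))
        # in a real PyTorch loop: total.backward(); optimizer.step(); optimizer.zero_grad()
        step_losses.append(total)
    return step_losses


# usage: three toy "loaders" whose batches are already their loss values
print(weighted_multi_dataset_steps([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0.5, 0.25, 0.25], lambda b: b))
```

Because `zip()` stops at the shortest loader, equally sized datasets (100K each here) keep every batch in play. With unequal sizes, wrapping the smaller loaders in `itertools.cycle` is one option, but `cycle` replays its saved first pass, so those batches would repeat in the same order.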

The code in JoeyNMT is beautifully written, but I believe I may need to make some changes to achieve this. Do you think there is a better way to approach it? Thank you so much for your support!

may- commented 2 months ago

Hi @vmenan,

Ah, OK, now I understand your project better. So you'd like to compute the loss separately for each dataset, right? In that case, you do indeed need to change the training manager class.

I can think of three scenarios:

  1. Simply concatenate the datasets and feed all instances, shuffled together, as a single dataset.
    • You have no direct control over how a batch is composed.
  2. During batch construction, take instances from each dataset (this is what I suggested above).
    • You can ensure the dataset mixture ratio within a single batch.
    • But the loss is computed over the entire batch, not separately per dataset.
  3. Pass 3 different batches, compute the loss for each separately, and merge them (as you described above).
    • This is closer to multi-task (joint) learning / loss interpolation.
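Scenario 2 can be sketched as a batch builder that draws a fixed number of examples from each dataset for every batch. This is a hypothetical, framework-free helper (`mixed_batches` and `per_batch` are illustrative names, not JoeyNMT API). It tags each example with its dataset name, so a per-dataset loss could still be recovered later if needed, and it stops once any dataset can no longer fill its quota.

```python
import random
from typing import Dict, List, Sequence, Tuple


def mixed_batches(
    datasets: Dict[str, Sequence],
    per_batch: Dict[str, int],  # hypothetical: examples to draw from each dataset per batch
    seed: int = 0,
) -> List[List[Tuple[str, object]]]:
    """Hypothetical batch builder with a fixed dataset mixture ratio in every batch."""
    rng = random.Random(seed)
    # one shuffled index order per dataset, so no example is used twice in an epoch
    orders = {name: rng.sample(range(len(data)), len(data)) for name, data in datasets.items()}
    # stop when the first dataset can no longer fill its quota
    n_batches = min(len(orders[name]) // k for name, k in per_batch.items())
    batches = []
    for i in range(n_batches):
        batch = []
        for name, k in per_batch.items():
            for j in orders[name][i * k:(i + 1) * k]:
                batch.append((name, datasets[name][j]))  # tag each example with its source
        rng.shuffle(batch)  # mix the datasets within the batch
        batches.append(batch)
    return batches
```

With `per_batch={"A": 2, "B": 1, "C": 1}`, every batch holds two examples from A and one each from B and C, i.e. a 2:1:1 mixture ratio.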

I cannot say which approach is better; it depends on your goal.