dllllb / pytorch-lifestream

A library built upon PyTorch for building embeddings on discrete event sequences using self-supervision
Apache License 2.0
220 stars, 48 forks

No way to log metrics in MLMPretrainModule #123

Closed. Ilykuleshov closed this issue 1 year ago

Ilykuleshov commented 1 year ago

Currently, MLMPretrainModule logs nothing but losses. To add metrics such as AUROC or F1Score, one has to copy and modify the mlm_loss function, because there is no other way to extract the model's predictions. The same goes for MLMNSPModule. IMHO, the class would be better off with a shared_step, like the one in ABSModule. An even better option would be to subclass ABSModule (and make ABSModule support a single-argument shared_step), or maybe to add a separate UnsupervisedABSModule.
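To make the proposal concrete, here is a minimal, framework-free sketch of the pattern I have in mind: a base class whose single-argument shared_step returns predictions and targets, so that both the loss and any extra metrics can be computed from the same outputs. All names here (SharedStepModule, ToyModule, the `logged` dict, the `valid/` prefix) are hypothetical and only illustrate the idea. This is not the actual ABSModule or MLMPretrainModule API, and plain Python lists stand in for tensors.

```python
from abc import ABC, abstractmethod


class SharedStepModule(ABC):
    """Hypothetical base class: subclasses expose predictions via shared_step."""

    def __init__(self, loss_fn, metrics):
        self.loss_fn = loss_fn    # callable(preds, targets) -> float
        self.metrics = metrics    # dict: name -> callable(preds, targets) -> float
        self.logged = {}          # stand-in for a real logger

    @abstractmethod
    def shared_step(self, batch):
        """Return (predictions, targets) for one batch."""

    def log(self, name, value):
        self.logged[name] = value

    def training_step(self, batch):
        preds, targets = self.shared_step(batch)
        loss = self.loss_fn(preds, targets)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch):
        # The same predictions feed every metric; no need to copy the loss code.
        preds, targets = self.shared_step(batch)
        for name, metric in self.metrics.items():
            self.log(f"valid/{name}", metric(preds, targets))


def mse(preds, targets):
    """Mean squared error over plain Python sequences."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def accuracy(preds, targets):
    """Share of predictions on the correct side of a 0.5 threshold."""
    hits = sum((p >= 0.5) == bool(t) for p, t in zip(preds, targets))
    return hits / len(preds)


class ToyModule(SharedStepModule):
    """Toy subclass: each batch item is already a (score, label) pair."""

    def shared_step(self, batch):
        preds = [score for score, _ in batch]
        targets = [label for _, label in batch]
        return preds, targets


module = ToyModule(loss_fn=mse, metrics={"accuracy": accuracy})
batch = [(0.9, 1), (0.2, 0), (0.6, 0), (0.4, 1)]
loss = module.training_step(batch)
module.validation_step(batch)
print(module.logged)
```

With this split, adding AUROC or F1Score would only mean passing another entry in `metrics`, assuming the module can produce predictions and targets in a form those metrics accept.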

Ilykuleshov commented 1 year ago

Excuse me, it seems I misunderstood: your implementation is actually quite far from RoBERTa, so no such metrics can be logged. It would be great if this were reflected in the docs. Closing this issue, my mistake.

ivkireev86 commented 11 months ago

Thanks for the question. Now I see that "ABSModule" is not a very good solution. There are a few common methods among modules that make sense to put in an abstract class. We plan to gradually remove "ABSModule".