UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

More modular fit() and support for progress and logging callbacks #1021

Open chiragjn opened 3 years ago

chiragjn commented 3 years ago

Hello, is there a plan to make .fit() more modular?

For context, I am integrating the library in an async worker, and I want to use Python/TensorBoard/wandb logging to log metrics, losses, etc. every n steps or epochs. The fit function is not modular enough at the moment for me to inherit from and override the right points. There is callback support, but it only works when an evaluator is provided. For example, libraries like fastai [1] and PyTorch Lightning [2] provide callbacks for before/after each batch/epoch.

[1] https://docs.fast.ai/callback.progress.html [2] https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html
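For illustration, a callback interface in that spirit might look roughly like the sketch below; every class and method name here is hypothetical and not part of sentence-transformers:

from typing import Any, Dict

class TrainerCallback:
    """Hypothetical base class: override only the hooks you need."""
    def on_epoch_begin(self, epoch: int) -> None: ...
    def on_epoch_end(self, epoch: int, logs: Dict[str, Any]) -> None: ...
    def on_batch_begin(self, step: int) -> None: ...
    def on_batch_end(self, step: int, loss: float) -> None: ...

class WandbCallback(TrainerCallback):
    def __init__(self, wandb_run, log_every_n_steps: int = 100):
        self.wandb_run = wandb_run
        self.log_every_n_steps = log_every_n_steps

    def on_batch_end(self, step: int, loss: float) -> None:
        if step % self.log_every_n_steps == 0:
            self.wandb_run.log({'train_loss': loss}, step=step)

# fit() would then accept something like a list of TrainerCallback objects and
# invoke each hook at the corresponding point in its training loop.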


If acceptable, I can help work on a PR for this :D

nreimers commented 3 years ago

Yes, it is planned to integrate a more modular fit (and evaluators) in version 2.1. They should provide options to log to tensorboard / wandb or to your own custom code.

A PR on this would be really nice. I haven't started yet on this topic.

chiragjn commented 3 years ago

Just saw that 2.0 was released a few days ago! :D I forked it and have started refactoring the code to add events via callbacks. I'll update when I have something working. In the meantime, I hacked the loss function a bit to be able to do some logging:

import logging
from timeit import default_timer as timer
from typing import Dict, Iterable, Optional

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers import losses as st_losses

logger = logging.getLogger(__name__)


class LoggedDenoisingAutoEncoderLoss(st_losses.DenoisingAutoEncoderLoss):
    """DenoisingAutoEncoderLoss that logs the running training loss every n steps."""

    def __init__(self,
                 model: SentenceTransformer,
                 decoder_name_or_path: Optional[str] = None,
                 tie_encoder_decoder: bool = True,
                 record_every_n_steps: int = 1000):
        super().__init__(model=model, decoder_name_or_path=decoder_name_or_path, tie_encoder_decoder=tie_encoder_decoder)
        self._record_every_n_steps = max(1, record_every_n_steps)
        self._step_counter = 0
        self._mtime = 0.0
        self._running_loss = torch.tensor(0.0)

    def _pre_forward(self):
        # Reset the timer and the running loss at the start of every logging window
        if self.encoder.training:
            if self._step_counter == 0 or self._step_counter % self._record_every_n_steps == 0:
                self._mtime = timer()
                self._running_loss = torch.tensor(0.0)

            self._step_counter += 1

    def _post_forward(self, loss: torch.Tensor):
        # Accumulate the loss and emit a log line at the end of every logging window
        if self.encoder.training:
            self._running_loss += loss.item()

            if self._step_counter % self._record_every_n_steps == 0:
                metrics = {
                    'step_range': f'{self._step_counter - self._record_every_n_steps + 1}-{self._step_counter}',
                    'time_elapsed': timer() - self._mtime,
                    'total_train_loss': self._running_loss.item(),
                    'avg_train_loss_per_step': self._running_loss.item() / self._record_every_n_steps,
                }
                logger.info(f'{type(self).__name__}: {metrics}')

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor) -> torch.Tensor:
        self._pre_forward()
        loss: torch.Tensor = super().forward(sentence_features=sentence_features, labels=labels)
        self._post_forward(loss=loss.clone().detach().cpu())
        return loss

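For completeness, here is a hedged sketch of how this wrapped loss could be plugged into fit() for TSDAE-style training; the model name, sentences, and hyperparameters below are placeholders:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer
from sentence_transformers.datasets import DenoisingAutoEncoderDataset

sentences = ['Example sentence one.', 'Example sentence two.']  # placeholder corpus

model = SentenceTransformer('bert-base-uncased')
train_dataset = DenoisingAutoEncoderDataset(sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

train_loss = LoggedDenoisingAutoEncoderLoss(
    model=model,
    decoder_name_or_path='bert-base-uncased',
    tie_encoder_decoder=True,
    record_every_n_steps=100,
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    show_progress_bar=True,
)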

skewwhiff commented 2 years ago

If there's no open PR for this, I'll be happy to take it up. I currently subclass SentenceTransformer and CrossEncoder and implement hard-coded wandb logging. I think it can be generalized to custom callbacks for tensorboard/wandb.
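Until a generic mechanism lands, the existing callback argument of fit() can already push evaluation scores to wandb, since it is invoked after each evaluation with (score, epoch, steps). A rough sketch, reusing model, train_dataloader, and train_loss from the earlier example and with placeholder dev data:

import wandb
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

wandb.init(project='sbert-finetune')  # placeholder project name

def wandb_eval_callback(score, epoch, steps):
    # fit() invokes this after every evaluation run
    wandb.log({'eval_score': score, 'epoch': epoch, 'steps': steps})

# placeholder dev data for the evaluator
dev_sentences1 = ['A man is eating food.']
dev_sentences2 = ['A person is eating something.']
dev_scores = [0.9]
evaluator = EmbeddingSimilarityEvaluator(dev_sentences1, dev_sentences2, dev_scores)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    evaluation_steps=500,
    callback=wandb_eval_callback,
    epochs=1,
)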

chiragjn commented 2 years ago

@skewwhiff Please feel free to take this up, and sorry for the lack of updates. I did some prototyping back then, but I was not happy with the architecture myself. I think an API like fastai's or PyTorch Lightning's would be pretty good. I'll be happy to help in any way I can.

skewwhiff commented 2 years ago

All right, I will open a new PR for generic logging. @nreimers, I have a broad overview of what to do.

Do you have anything else in mind wrt logging?

Exr0n commented 2 years ago

chiragjn's loss function looks great, but I was put off by how complicated it looked. For anyone else looking for a quick hack, I think this is the basic idea:

import torch
import wandb

class LoggingLoss:
    """Wraps a torch loss function and logs each training loss value to wandb."""
    def __init__(self, loss_fn, wandb):
        self.loss_fn = loss_fn
        self.wandb = wandb

    def __call__(self, logits, labels):
        loss = self.loss_fn(logits, labels)
        self.wandb.log({'train_loss': loss.item()})  # .item() detaches the scalar before logging
        return loss

# ...

wandb.init()
wandb.watch(model.model)  # model.model: the underlying transformers model of the CrossEncoder
model.fit(
    # ...
    loss_fct=LoggingLoss(torch.nn.BCEWithLogitsLoss(), wandb),  # loss_fct is a CrossEncoder.fit argument
    # ...
)

Looking forward to hooks and proper integration! :)

SnoozingSimian commented 1 year ago

Commenting to keep an eye on this issue. I did it by hacking the fit() function in SentenceTransformer.py, but I guess it makes more sense to use the losses to do the logging.
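Along those lines, a generic wrapper can log from inside any loss module without modifying fit(). This is only a sketch; the WandbLossWrapper name and behavior are not part of the library:

from torch import nn
import wandb


class WandbLossWrapper(nn.Module):
    """Wraps any sentence-transformers loss module and logs the training loss to wandb."""

    def __init__(self, loss_module: nn.Module, log_every_n_steps: int = 100):
        super().__init__()
        # Registering the wrapped loss as a submodule keeps its parameters visible to fit()'s optimizer
        self.loss_module = loss_module
        self.log_every_n_steps = max(1, log_every_n_steps)
        self._step = 0

    def forward(self, sentence_features, labels):
        loss = self.loss_module(sentence_features, labels)
        if self.training:
            self._step += 1
            if self._step % self.log_every_n_steps == 0:
                wandb.log({'train_loss': loss.item()}, step=self._step)
        return loss


# Usage: wrap any existing loss before passing it to fit(), e.g.
# train_loss = WandbLossWrapper(st_losses.MultipleNegativesRankingLoss(model), log_every_n_steps=50)
# model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)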

minhrongcon2000 commented 1 year ago

Hi! Are there any updates on this issue?

johnsonice commented 1 year ago

Any updates? It is a bit crazy that we can't get the training loss?!