chiragjn opened this issue 3 years ago
Yes, it is planned to integrate a more modular `fit` (and evaluators) in version 2.1. These should provide options to log to tensorboard/wandb or to your own custom code.
A PR on this would be really nice. I haven't started on this topic yet.
Just saw 2.0 was released a few days ago! :D I forked it and have started refactoring the code to add events via callbacks. I'll update when I have something working. In the meantime, I kinda hacked the loss function to be able to do some logging:
```python
import logging
from timeit import default_timer as timer
from typing import Dict, Iterable, Optional

import torch
import sentence_transformers as st
from sentence_transformers import losses as st_losses

logger = logging.getLogger(__name__)


class LoggedDenoisingAutoEncoderLoss(st_losses.DenoisingAutoEncoderLoss):
    def __init__(self,
                 model: st.SentenceTransformer,
                 decoder_name_or_path: Optional[str] = None,
                 tie_encoder_decoder: bool = True,
                 record_every_n_steps: int = 1000):
        super().__init__(model=model,
                         decoder_name_or_path=decoder_name_or_path,
                         tie_encoder_decoder=tie_encoder_decoder)
        self._record_every_n_steps = max(1, record_every_n_steps)
        self._step_counter = 0
        self._mtime = 0.0
        self._running_loss = torch.tensor(0.0)

    def _pre_forward(self):
        if self.encoder.training:
            # Start a new logging window at step 0 and after every n-th step
            if self._step_counter == 0 or self._step_counter % self._record_every_n_steps == 0:
                self._mtime = timer()
                self._running_loss = torch.tensor(0.0)
            self._step_counter += 1

    def _post_forward(self, loss: torch.Tensor):
        if self.encoder.training:
            self._running_loss += loss.item()
            if self._step_counter % self._record_every_n_steps == 0:
                metrics = {
                    'step_range': f'{self._step_counter - self._record_every_n_steps + 1}-{self._step_counter}',
                    'time_elapsed': timer() - self._mtime,
                    'total_train_loss': self._running_loss.item(),
                    'avg_train_loss_per_step': self._running_loss.item() / self._record_every_n_steps,
                }
                logger.info(f'{type(self).__name__}: {metrics}')

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor) -> torch.Tensor:
        self._pre_forward()
        loss: torch.Tensor = super().forward(sentence_features=sentence_features, labels=labels)
        # Detach before logging so we never hold on to the autograd graph
        self._post_forward(loss=loss.clone().detach().cpu())
        return loss
```
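For completeness, wiring it up looks roughly like this (a minimal sketch; `train_sentences` is a placeholder for your own list of raw training strings):

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, datasets

model = SentenceTransformer('bert-base-uncased')
# train_sentences: placeholder for whatever raw strings you train TSDAE on
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
train_loss = LoggedDenoisingAutoEncoderLoss(
    model=model,
    decoder_name_or_path='bert-base-uncased',
    tie_encoder_decoder=True,
    record_every_n_steps=1000,
)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, show_progress_bar=True)
```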
If there's no open PR for this, I'll be happy to take it up. I currently subclass SentenceTransformer and CrossEncoder and implement hard-coded wandb logging. I think it can be generalized to custom callbacks for tensorboard/wandb.
@skewwhiff Please feel free to take this up, sorry for the lack of updates. I did some prototyping back then but I was not happy with the architecture myself. I think an API like fastai's or pytorch-lightning's would be pretty good. I'll be happy to help in any way I can.
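Something along these lines, perhaps (purely a sketch; none of these names exist in sentence-transformers today):

```python
class TrainerCallback:
    """Hypothetical hook interface, loosely modeled on fastai / pytorch-lightning callbacks."""
    def on_train_begin(self, model): ...
    def on_epoch_begin(self, model, epoch: int): ...
    def on_step_end(self, model, step: int, loss: float): ...
    def on_epoch_end(self, model, epoch: int): ...
    def on_train_end(self, model): ...


class WandbCallback(TrainerCallback):
    """Logs the per-step training loss to an existing wandb run."""
    def __init__(self, run):
        self.run = run

    def on_step_end(self, model, step: int, loss: float):
        self.run.log({'train_loss': loss}, step=step)


# fit() would then take callbacks=[WandbCallback(run)] and invoke each hook
# at the matching point in its training loop.
```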
All right, @nreimers, I will open a new PR for generic logging. I have a broad overview of what to do.
Do you have anything else in mind w.r.t. logging?
chiragjn's loss function is great, but I was put off by how complicated it looked. For anyone else looking for a quick hack, I think this is the basic idea:
```python
class LoggingLoss:
    """Wraps any loss function and logs every value to wandb."""
    def __init__(self, loss_fn, wandb):
        self.loss_fn = loss_fn
        self.wandb = wandb

    def __call__(self, logits, labels):
        loss = self.loss_fn(logits, labels)
        self.wandb.log({'train_loss': loss.item()})  # .item() detaches the scalar
        return loss

# ...
wandb.init()
wandb.watch(model.model)
model.fit(
    # ...
    loss_fct=LoggingLoss(torch.nn.BCEWithLogitsLoss(), wandb),
    # ...
)
```
Looking forward to hooks and proper integration! :)
Commenting to keep an eye on this issue. I did it by hacking the `fit()` function in `SentenceTransformer.py`, but I guess it makes more sense to use the losses to do the logging.
Hi! Are there any updates on this issue?
Any updates? It is a bit crazy that we can't get the training loss!
Hello, is there a plan to make `.fit` more modular? For context, I am integrating the library in an async worker and I want to use python/tensorboard/wandb logging to log metrics, losses, etc. every n steps or epochs. The `fit` function at the moment is not modular enough for me to inherit and override the right points. There is callback fn support, but it only works if an evaluator is provided (see the sketch at the end of this comment). For comparison, libraries like fastai and pytorch-lightning provide callbacks for before/after batch/epoch, etc.
[1] https://docs.fast.ai/callback.progress.html [2] https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html
If acceptable I can help work on a PR for this :D
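For anyone who lands here before that PR exists: the closest thing in the current API is the evaluator-gated callback mentioned above. `fit` accepts a `callback(score, epoch, steps)` that fires after each evaluation, never per training step. A rough sketch, assuming `model`, `train_dataloader`, `train_loss`, and `evaluator` are already constructed:

```python
import wandb

def log_eval(score, epoch, steps):
    # Invoked only after each evaluator run, not per training batch
    wandb.log({'eval_score': score, 'epoch': epoch, 'steps': steps})

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,       # without an evaluator, `callback` is never invoked
    evaluation_steps=1000,
    callback=log_eval,
)
```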