ddkang / loss_dropper

Apache License 2.0

Improved Natural Language Generation via Loss Truncation

This is the official repository for Improved Natural Language Generation via Loss Truncation.

We provide code for loss dropping.

What's wrong with log loss?

Neural language models are typically trained via log loss. While straightforward to optimize, even small fractions of noisy data (e.g., misannotations and hallucinated facts) can degrade the performance of log loss. As an alternative, prior work has shown that minimizing the distinguishability of generated samples is a principled and robust loss that can handle invalid references. However, distinguishability has not been used in practice due to challenges in optimization and estimation.

What is loss truncation?

Loss truncation is a simple and scalable procedure that adaptively removes high log loss examples as a way to optimize for distinguishability. We demonstrate that loss truncation outperforms existing baselines on distinguishability on a summarization task, and show that samples generated by the loss truncation model have factual accuracy ratings that exceed those of baselines and match human references.

See our paper for full details.
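To make the idea concrete, here is a minimal plain-Python sketch of truncating a batch of per-sequence losses. It is an illustration only, not the implementation in this repository: `truncation_mask` and `truncated_mean` are hypothetical helpers, the threshold is computed from the current batch alone, and `dropc` is assumed to mean the fraction of examples to drop.

```python
import math


def truncation_mask(losses, dropc):
    """Return a 0/1 mask that zeroes the highest-loss `dropc` fraction.

    Hypothetical helper for illustration; not part of loss_dropper.
    """
    if not 0.0 <= dropc < 1.0:
        raise ValueError("dropc must be in [0, 1)")
    n_drop = math.floor(dropc * len(losses))
    # Indices ordered from highest loss to lowest.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    dropped = set(order[:n_drop])
    return [0.0 if i in dropped else 1.0 for i in range(len(losses))]


def truncated_mean(losses, dropc):
    """Mask out the highest losses, then average over the full batch."""
    mask = truncation_mask(losses, dropc)
    # Dropped examples contribute zero but still count in the denominator
    # (multiply by the mask, then take the mean).
    return sum(l * m for l, m in zip(losses, mask)) / len(losses)


# One outlier (7.5) stands in for a noisy reference.
seq_losses = [0.9, 1.1, 7.5, 1.0, 0.8]
print(truncation_mask(seq_losses, dropc=0.2))  # [1.0, 1.0, 0.0, 1.0, 1.0]
print(truncated_mean(seq_losses, dropc=0.2))
```

With `dropc=0.2` and five examples, exactly one example (the outlier) is masked, so it no longer contributes to the averaged loss.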

Installation

We require Python 3.5+ and torch 1.0+.

To install loss_dropper, in a virtual environment of your choice, run:

pip install -U git+https://github.com/ddkang/loss_dropper.git

Usage

  1. Import loss_dropper:

    from loss_dropper import LossDropper

  2. Initialize LossDropper:

    self.dropper = LossDropper(dropc=dropc)

  3. Initialize your loss:

    self.criterion = nn.NLLLoss(weight, reduction='none')

    IMPORTANT: loss truncation drops examples at the sequence level, so the criterion must return unreduced per-token losses. Any reduction other than 'none' aggregates over the wrong dimensions for truncation.

  4. Do loss dropping:

    loss = loss.view(-1, batch_size)  # view by batch size
    loss = loss.mean(dim=0)  # aggregate by sequence
    mask = self.dropper(loss)  # The dropper returns a mask of 0s where data should be dropped.
    loss *= mask  # Mask out the high losses
    loss = loss.mean()  # Aggregate

    IMPORTANT: depending on how your loss is computed, you may have to aggregate in a different way.
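The aggregation in step 4 depends on how the flattened per-token losses are laid out. The sketch below uses plain Python lists so the index arithmetic is explicit. It assumes a time-major layout, where the loss of token `t` in sequence `b` sits at index `t * batch_size + b`; this appears to be the layout that `view(-1, batch_size)` followed by `mean(dim=0)` expects. `per_sequence_mean` is a hypothetical helper, not part of loss_dropper.

```python
def per_sequence_mean(flat_losses, batch_size):
    """Average token losses per sequence for a time-major flat layout.

    Assumes flat_losses[t * batch_size + b] is the loss of token t in
    sequence b. Hypothetical helper for illustration only.
    """
    if batch_size <= 0 or len(flat_losses) % batch_size != 0:
        raise ValueError("length of flat_losses must be a multiple of batch_size")
    seq_len = len(flat_losses) // batch_size
    return [
        sum(flat_losses[t * batch_size + b] for t in range(seq_len)) / seq_len
        for b in range(batch_size)
    ]


# Two sequences, three time steps each, interleaved time-major:
# sequence 0 has token losses 1, 2, 3 and sequence 1 has 4, 5, 6.
flat = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
print(per_sequence_mean(flat, batch_size=2))  # [2.0, 5.0]
```

If your losses are instead flattened batch-major (all tokens of sequence 0 first, then sequence 1, and so on), the equivalent tensor operations would be `loss.view(batch_size, -1)` followed by `loss.mean(dim=1)`. This sketch also averages over every position; if your sequences are padded, you may want to exclude padding positions from the per-sequence average.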

Citation

If you find this useful in your research, please consider citing:

@article{kang2020improved,
  title={Improved Natural Language Generation via Loss Truncation},
  author={Daniel Kang and Tatsunori Hashimoto},
  journal={ACL},
  year={2020}
}