huggingface / transfer-learning-conv-ai

🦄 State-of-the-Art Conversational AI with Transfer Learning

Error trying to compute unigram precision/recall/F1 #39

Closed g-karthik closed 5 years ago

g-karthik commented 5 years ago

https://github.com/huggingface/transfer-learning-conv-ai/blob/3bac4924fa307d5e4eba977818b2a29d403e8371/train.py#L235

So I tried extending the metrics above, which are computed during training, to additionally compute unigram precision, recall, and F1 on the validation set, since these are supported in pytorch-ignite. I did something like the following:

    eval_metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
                    "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])),
                    "precision": Precision(output_transform=lambda x: (x[0][0], x[1][0])),
                    "recall": Recall(output_transform=lambda x: (x[0][0], x[1][0]))}
    eval_metrics["ppl"] = MetricsLambda(math.exp, eval_metrics["nll"])

    eval_metrics["f1"] = eval_metrics["precision"] * eval_metrics["recall"] * 2 / (eval_metrics["precision"] +
                                                                                   eval_metrics["recall"] + 1e-20)
    eval_metrics["f1"] = MetricsLambda(lambda t: torch.mean(t).item(), eval_metrics["f1"])

    eval_metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, eval_metrics["nll"], args),
                         "average_accuracy": MetricsLambda(average_distributed_scalar, eval_metrics["accuracy"], args),
                         "average_f1": MetricsLambda(average_distributed_scalar, eval_metrics["f1"], args)})
    eval_metrics["average_ppl"] = MetricsLambda(math.exp, eval_metrics["average_nll"])
    for name, metric in eval_metrics.items():
        metric.attach(evaluator, name)

Does this seem right? @thomwolf

I'm getting a device-side assert failure with the following stack trace when I launch the distributed training job with CUDA_LAUNCH_BLOCKING=1:

Traceback (most recent call last):
  File "train.py", line 402, in <module>
    train()
  File "train.py", line 391, in train
    trainer.run(train_loader, max_epochs=args.n_epochs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 446, in run
    self._handle_exception(e)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 429, in run
    self._fire_event(Events.STARTED)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 345, in _fire_event
    func(self, *(event_args + args), **kwargs)
  File "train.py", line 318, in <lambda>
    trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 446, in run
    self._handle_exception(e)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 433, in run
    hours, mins, secs = self._run_once_on_dataset()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 399, in _run_once_on_dataset
    self._handle_exception(e)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 410, in _handle_exception
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 392, in _run_once_on_dataset
    self._fire_event(Events.ITERATION_COMPLETED)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/engine/engine.py", line 345, in _fire_event
    func(self, *(event_args + args), **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 43, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/metrics/metric.py", line 65, in iteration_completed
    self.update(output)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/metrics/precision.py", line 101, in update
    y = to_onehot(y.view(-1), num_classes=num_classes)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ignite/utils.py", line 52, in to_onehot
    return onehot.scatter_(1, indices.unsqueeze(1), 1)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/THCTensorScatterGather.cu:380
terminate called without an active exception

Any ideas on why this is happening?

g-karthik commented 5 years ago

Ah, after a moderate dive through the code and printing the relevant tensors, I'm guessing this happens because lm_labels contains a bunch of -1s that mask the context (and the negative candidates) for the LM head. Those -1s throw off the Precision and Recall classes in pytorch-ignite, unlike the CrossEntropyLoss class, which is constructed with ignore_index=-1 precisely to handle them.
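
If so, the failure is easy to reproduce in isolation: per the stack trace, ignite's to_onehot hands the raw labels to scatter_ as indices, and -1 is out of bounds for the one-hot dimension. A minimal sketch of the same failure in plain PyTorch (tensor values made up for illustration):

    import torch

    labels = torch.tensor([5, 2, -1])          # -1 marks masked positions, as in lm_labels
    onehot = torch.zeros(labels.size(0), 10)   # pretend the vocabulary has 10 classes
    # scatter_ requires indices in [0, num_classes); the -1 is rejected
    # (a device-side assert on CUDA, an index error on CPU).
    onehot.scatter_(1, labels.unsqueeze(1), 1)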

I wonder whether there's an equivalent of ignore_index for Precision and Recall, or whether I'd have to handle the -1s myself by modifying the output_transform lambda.

g-karthik commented 5 years ago

I created a transform function for Precision and Recall, as mentioned above, and the error is gone!

    def precision_recall_transform(x):
        # Keep only the positions whose lm_label isn't -1 (the mask value for context and candidates).
        logits, labels = x[0][0], x[1][0]
        condition = (labels != -1)
        return logits[condition], labels[condition]
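
For completeness, here's how I wired it in (same constructors as in my first snippet):

    eval_metrics["precision"] = Precision(output_transform=precision_recall_transform)
    eval_metrics["recall"] = Recall(output_transform=precision_recall_transform)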

However, I see that the actual F1 scores computed this way are unusually low (around 0.015 rather than the ~0.15 one would expect for unigram F1 with models like these, based on what's reported in the literature). Is this something others see as well?

g-karthik commented 5 years ago

Oh, never mind. I realized that what ignite's Precision/Recall classes compute isn't really the unigram precision/recall I was looking for, so it makes sense that the scores are unusually low.
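
For anyone landing here later: the unigram F1 usually reported in the dialogue literature is a bag-of-words overlap between the generated reply and the reference reply, computed per example and then averaged, whereas ignite's Precision/Recall operate per vocabulary class on the logits. A rough sketch of the overlap-style computation (the helper name is mine, not from this repo):

    from collections import Counter

    def unigram_f1(predicted_tokens, reference_tokens):
        # Multiset intersection counts the overlapping unigrams.
        common = Counter(predicted_tokens) & Counter(reference_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(predicted_tokens)
        recall = num_same / len(reference_tokens)
        return 2 * precision * recall / (precision + recall)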