Rough, sorry about that. This is the offending line:
https://github.com/allenai/allennlp/blob/master/allennlp/modules/span_pruner.py#L74
Do you want to send a PR which uses `replace_masked_values` with a very negative, but not `-inf`, value instead, and adds your snippet as a small test case? You are also correct that it could be used for vectors etc. We could call it pruner 👍
I'm actually confused by this - where does the multiplication with 0 happen?
We use a log mask in a few different places, and it appears to work fine there (e.g., for antecedents). There will definitely be `-inf`s in the coreference scores after you add the antecedent log mask, for the current coref model on our regular dataset. If the issue is as simple as you think, why doesn't this crash? I think there's something else going on here.
The multiplication with 0 happens afterwards, sorry:

```python
_, top_mask, _, top_scores = pruner(emb, mask, 2)
top_scores.squeeze(-1) * top_mask
```
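To make that concrete, a minimal standalone sketch: under IEEE float arithmetic, `(-inf) * 0` is `nan`, not `0`, so masking by multiplication destroys exactly the slots the mask is meant to zero out.

```python
import torch

scores = torch.tensor([0.5, float("-inf")])  # second slot is a fully masked score
mask = torch.tensor([1.0, 0.0])

# (-inf) * 0 is nan under IEEE arithmetic, not 0:
print(scores * mask)  # tensor([0.5000, nan])
```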
This happened to me when doing `torch.nn.functional.binary_cross_entropy_with_logits` over the resulting scores while passing `top_span_mask` as the `weight` parameter. Was that my mistake?
If you have logits, `-inf` is the right value, right? Why do you need to multiply by the mask first? Multiplying things in logspace is generally wrong - using `-inf` works because those values are typically only passed directly to a loss, or to a softmax, which will convert them back into zeros.
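A minimal sketch of why `-inf` logits are safe when they only ever reach a softmax: the exponential maps them back to exact zeros.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, float("-inf"), 2.0])

# exp(-inf) == 0, so the masked position gets exactly zero probability:
print(F.softmax(logits, dim=-1))  # tensor([0.2689, 0.0000, 0.7311])
```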
Yes, that's a good point... I will keep digging and try to find something that shows the issue end-to-end.
Ok, so here's where my `nan` is coming from at the end of the day. It seems like a reasonable template for doing binary classification on spans with pruning:
```python
import torch
import torch.nn.functional as F

from allennlp.modules.span_pruner import SpanPruner
from allennlp.nn import util

emb = torch.ones([1, 2, 1])                     # batch size 1, 2 spans, embedding size 1
gold_labels = torch.zeros([1, 2]).float()
mask = torch.tensor([1, 0]).view(1, 2).float()  # only 1 span is present in the instance
scorer = torch.nn.Linear(1, 1)
pruner = SpanPruner(scorer)

# Keep 2 spans even though only 1 is unmasked, so padding leaks into the output.
_, top_mask, top_indices, top_scores = pruner(emb, mask, 2)
top_gold_labels = util.batched_index_select(gold_labels.unsqueeze(-1), top_indices).squeeze(-1)
loss = F.binary_cross_entropy_with_logits(top_scores.squeeze(-1), top_gold_labels, weight=top_mask)
print(loss)
```
This prints `tensor(nan)`.
Assuming we're fine with `-inf` appearing in the logits, this is presumably a problem with how I'm using masking with `F.binary_cross_entropy_with_logits`.
Yeah, that function must be multiplying by the weight, which is giving you nans. What happens if you use our `sequence_cross_entropy_with_logits`?
Alright, actually yes, it's a separate issue from it not working with masking. It's that `binary_cross_entropy_with_logits` can't deal with `-inf`. Looking at the code here, it has the expression `(-input - max_val).exp()`, where our `input` is `-inf` and `max_val` is `inf`. This produces `nan` instead of the desired `0.0`.
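The offending step can be reproduced in isolation. A sketch mirroring the quoted expression (the exact form of PyTorch's numerically stable BCE rewrite may vary by version, but the `inf - inf` step is the core of it):

```python
import torch

x = torch.tensor(float("-inf"))  # a fully masked logit
max_val = (-x).clamp(min=0)      # inf, in the stable-BCE rewrite

# inf - inf is nan, so the exp is nan rather than the mathematically expected 0:
print((-x - max_val).exp())      # tensor(nan)
```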
It looks like our `sequence_cross_entropy_with_logits` is also going to give you nans, because we call `torch.nn.functional.log_softmax`, which returns `nan` if the whole input is `-inf`. So, you need to do something like `replace_masked_values` first (or we need to change `sequence_cross_entropy_with_logits` to call `masked_log_softmax`, which deals with this issue just fine).
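A sketch of the difference, assuming `masked_log_softmax` from `allennlp.nn.util` (which, as noted later in this thread, uses `(mask + 1e-45).log()` rather than a plain log mask):

```python
import torch
import torch.nn.functional as F
from allennlp.nn.util import masked_log_softmax

logits = torch.zeros(1, 3)
mask = torch.zeros(1, 3)  # the entire row is masked out

# A plain log mask makes the whole row -inf, and log_softmax of that is nan:
print(F.log_softmax(logits + mask.log(), dim=-1))  # tensor([[nan, nan, nan]])

# (mask + 1e-45).log() is very negative but finite, so the result stays finite:
print(masked_log_softmax(logits, mask))  # finite values, roughly log(1/3) each
```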
`sequence_cross_entropy_with_logits` does multiclass whereas I was doing binary classification. I could reshape things a bit but I would rather not since I'm adding more scores to the outputs and doing binary classification in the end anyway. It doesn't seem like there's an alternative to `sequence_cross_entropy_with_logits` that does binary cross-entropy. Is there?
Anyway, just replacing the masked values worked around this and fixed my problem at the end of the day (I did `top_span_scores[top_span_scores == float("-inf")] = -1.` and it worked fine). Maybe `replace_masked_values` is a bit cleaner, so I'll go ahead and do that.
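For concreteness, a minimal sketch of that workaround, continuing from the repro snippet above (variable names follow that snippet; this is illustrative, not the exact fix):

```python
# Replace -inf with a harmless finite value before the loss; the weight/mask
# still zeroes out those positions' contribution.
top_scores = top_scores.squeeze(-1)
top_scores[top_scores == float("-inf")] = -1.0
loss = F.binary_cross_entropy_with_logits(top_scores, top_gold_labels, weight=top_mask)
print(loss)  # now finite
```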
The `-inf`s might be something to watch out for in the future... since pytorch doesn't seem robust to them, these kinds of issues might crop up in other places as well.
Oh, good point. And, as I mentioned above, that function won't solve your problem. Either we change all of our places where we add a log mask to using `replace_masked_values`, which seems like a net loss for efficiency (but maybe a slight gain for readability...), or you just need to remember to do `replace_masked_values` yourself right before the loss computation, if you ever expect to have wholly masked decisions.
Actually, `replace_masked_values` doesn't work, because it multiplies by the mask. So it just replaces the `-inf`s with `nan`s.
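A sketch of why, assuming the pre-fix implementation replaced values with arithmetic roughly of the form `tensor * mask + replace_with * (1 - mask)` (an assumption based on the "multiplies by the mask" description above):

```python
import torch

scores = torch.tensor([0.5, float("-inf")])
mask = torch.tensor([1.0, 0.0])
replace_with = -1.0

# The masked slot computes (-inf) * 0 + (-1.0) * 1 = nan + (-1.0) = nan,
# so the nan appears before the replacement value can take effect:
print(scores * mask + replace_with * (1 - mask))  # tensor([0.5000, nan])
```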
We've changed that on master; if you update to 0.6.1 (which was just released), it should solve your problem.
Cool :) From searching around the PyTorch forums, it seems they discourage having `inf`s anywhere because of the ways they can cause problems. Perhaps it's worth documenting in AllenNLP anywhere a tensor being returned might have `inf` values, and whether arguments to library methods can/can't handle `inf`.
Ok, I looked through our code, and there are exactly two places in allennlp where this is an issue. One is in the span pruner, and one is in the coref model. In our `masked_log_softmax`, we do `(mask + 1e-45).log()`, to avoid the `-inf` problem. I'll just make a quick PR to fix those other two places, as @DeNeutoy originally suggested =).
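The trick in isolation: `1e-45` is on the order of the smallest subnormal float32, so its log is very negative but still finite.

```python
import torch

mask = torch.tensor([1.0, 0.0])

print(mask.log())            # tensor([0., -inf])          -- the dangerous version
print((mask + 1e-45).log())  # tensor([0.0000, -103.2789]) -- finite stand-in for -inf
```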
**Describe the bug**
In the case when there are fewer than `num_spans_to_keep` total spans in the original text, some padding makes its way into the `top_span_scores` output of `SpanPruner` with scores of `-inf`. Even though the `top_spans_mask` output is correct, this is a problem because multiplying the scores by the mask produces `nan` in those slots instead of the desired `0.0`.

**To Reproduce**
In a Python REPL:
For me, it outputs:

though of course the non-`inf` number is arbitrary.
**Expected behavior**
I think in this case we should replace the `-inf`s with `-1`. Because of this issue I had a loss of `nan` that I had to debug until I found this. It should be an easy fix in `SpanPruner`. BTW, there's nothing particular to spans in `SpanPruner`, is there? Might as well just call it `Pruner`, right?
**Additional context**
My guess is this hasn't come up before because span pruning was only used with long texts where this doesn't ever happen. It came up for me because I'm using span pruning with a sentence-level model where a few of the sentences have only 2 tokens and are batched with 4-token sentences.