AIPHES / DiscoScore

DiscoScore: Evaluating Text Generation with BERT and Discourse Coherence

Interface for Truncation #6

Open UntotaufUrlaub opened 1 year ago

UntotaufUrlaub commented 1 year ago

Hi,

I think an interface for configuring whether the input is truncated would be pretty handy.

Of course, one could handle the truncation before passing the texts to the metric, but that seems like unnecessary overhead for a use case that could come up quite often. I have seen that one fork started experimenting in this direction, but I suspect that approach does not work. They configured the tokenizer to perform truncation, but if I understand correctly, the input text is not processed only by the tokenizer: for example, in DS_Sent at line 153 it is also used to create the entity graph.

What do you think would be the best approach to this issue? I assume it would be best to truncate the text as early as possible, so that not every piece of code that uses it has to be modified and bloated. I see two options:

  1. Pass a truncation argument to DiscoScorer(), which is then used in all the methods (LexicalChain, DS_SENT_NN, ...) to decide on truncation. This would keep all calls of one scorer consistent, and the individual calls would keep their simple interface. But I am not sure whether this is flexible enough, or whether truncation makes sense for the non-DS methods.
  2. Add a truncation argument to the 4 DS_... methods.

Both options would perform the truncation in the scorer file and pass the preprocessed texts on to the DS logic. The disadvantage is that one has to tokenize, truncate, and repack the tokens into text, which adds some computational overhead.
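The tokenize, truncate, and repack round trip described above could look roughly like the sketch below. `truncate_text` is a hypothetical helper, not part of DiscoScore; it takes the tokenizer's encode/decode functions as parameters, so with a Hugging Face tokenizer one might pass something like `lambda t: tok.encode(t, add_special_tokens=False)` and `tok.decode`. The whitespace tokenizer in the usage lines is only a stand-in.

```python
from typing import Callable, List, Optional


def truncate_text(
    text: str,
    encode: Callable[[str], List[int]],
    decode: Callable[[List[int]], str],
    max_tokens: Optional[int],
) -> str:
    """Tokenize, keep at most max_tokens tokens, and turn them back into text.

    Hypothetical helper: encode/decode stand in for whatever tokenizer the
    scorer uses. max_tokens=None means "no truncation".
    """
    if max_tokens is None:
        return text
    ids = encode(text)
    if len(ids) <= max_tokens:
        return text  # already short enough, so skip the (possibly lossy) decode
    return decode(ids[:max_tokens])


# Toy whitespace "tokenizer", used only for illustration.
vocab: List[str] = []


def toy_encode(text: str) -> List[int]:
    ids = []
    for word in text.split():
        if word not in vocab:
            vocab.append(word)
        ids.append(vocab.index(word))
    return ids


def toy_decode(ids: List[int]) -> str:
    return " ".join(vocab[i] for i in ids)


print(truncate_text("the daily and sunday politics are on air", toy_encode, toy_decode, 4))
```

Decoding may not reproduce the original spacing or casing exactly, which is one more reason to do this once, up front, rather than separately in each method.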

andyweizhao commented 1 year ago

Hi,

When dealing with long input texts, I think the simplest solution is to use Longformer or its successors. Of course, there is still a limit on sequence length, so truncation is still needed at some point.

As for truncation, I think the right solution is to handle it at the very beginning, before passing the input texts to the metric. This is because the metric tokenizes every input text twice, producing words (via UDPipe) and word pieces (via the BERT tokenizer) for different purposes. It is necessary to ensure that both tokenization steps operate on the same text.
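Following that advice, one option is to truncate once, before calling the scorer, and to cut only at sentence boundaries, so that the word-level (UDPipe) and word-piece (BERT) tokenizations both see the same complete sentences. This is a sketch under stated assumptions: `truncate_by_sentences` is hypothetical, the sentences are assumed to be split already, and `count_tokens` stands in for the model tokenizer's length function.

```python
from typing import Callable, List


def truncate_by_sentences(
    sentences: List[str],
    count_tokens: Callable[[str], int],
    max_tokens: int,
) -> str:
    """Keep leading sentences while the joined text stays within max_tokens.

    Hypothetical helper. Cutting at sentence boundaries means no sentence is
    split in half, so every downstream tokenizer receives the same, complete
    sentences.
    """
    kept: List[str] = []
    for sentence in sentences:
        candidate = " ".join(kept + [sentence])
        if count_tokens(candidate) > max_tokens:
            break
        kept.append(sentence)
    return " ".join(kept)


# Stand-in token counter: one token per whitespace-separated word.
def count_words(text: str) -> int:
    return len(text.split())


sents = ["the daily politics is on air .", "it reports from westminster .", "and beyond ."]
print(truncate_by_sentences(sents, count_words, 12))
```

If even the first sentence exceeds the budget, this returns an empty string; a real implementation would need a fallback such as token-level truncation of that sentence.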

Hope this helps.

UntotaufUrlaub commented 1 year ago

Hi,

I created a fork for this topic. Would you please have a look? If you think the change is helpful, I could open a pull request. I added truncation and adjusted the code so that Longformer can be used without manual code changes. Do you know whether the suggested Longformer model performs comparably to Conpono and BERT-NLI?

andyweizhao commented 1 year ago

Hi, thank you for this! I think your change is correct. Did you run tests to verify it? Regarding Longformer, the answer is no: it underperforms Conpono and BERT-NLI on SummEval.

UntotaufUrlaub commented 1 year ago

Hi, yes, I did run some tests. The truncation part seems to work fine. The Longformer part has some issues: there is one bug I have not been able to solve yet (maybe running on CPU will make debugging easier :) ), and the config of the downloaded models needs to be adjusted to work with the Auto classes. See the open issue in the fork.

If you are interested, I could share the details of the bug with you once I have time.

Thanks for the performance info.

andyweizhao commented 1 year ago

Hi, I am happy to look into the issue once you provide me with details.

UntotaufUrlaub commented 1 year ago

I created an example showing the error. Replace the ... in the doc definition with the content of the attached file (it is too long to display here).

from disco_score import DiscoScorer

doc = """..."""
summ = """the daily and sunday politics are on-air six days a week for much of the year reporting the political news from westminster and beyond."""

scorer = DiscoScorer(device="cpu",  # or device='cuda:0'
                     model_name='allenai/longformer-base-4096',
                     truncation=None)
scorer.DS_SENT_NN(summ, [doc])

The final error is

File "/.../python3.11/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

(Should I add the whole stack trace?)

I tried searching for it online, but due to a lack of experience I couldn't connect the issues discussed there to this problem.
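For context: in PyTorch, `IndexError: index out of range in self` inside `torch.embedding` means some input index is at least the number of rows in an embedding table. With long inputs, a likely culprit is the position embedding table rather than the vocabulary. Longformer, like RoBERTa, numbers non-padding positions starting at `padding_idx + 1`; assuming `padding_idx = 1` and `max_position_embeddings = 4098` (values believed to be in the `allenai/longformer-base-4096` config, worth checking in its `config.json`), at most 4096 tokens, special tokens included, fit. On GPU the same overflow typically surfaces as a less readable CUDA device-side assert. The pure-Python sketch below only mimics that numbering; it is an illustration, not the transformers implementation.

```python
from typing import List


def roberta_style_position_ids(input_ids: List[int], padding_idx: int = 1) -> List[int]:
    """Mimic RoBERTa/Longformer-style position numbering (illustrative only).

    Padding tokens get position padding_idx; the k-th non-padding token
    (k starting at 1) gets position padding_idx + k.
    """
    positions = []
    count = 0
    for token_id in input_ids:
        if token_id == padding_idx:
            positions.append(padding_idx)
        else:
            count += 1
            positions.append(padding_idx + count)
    return positions


def fits_position_table(
    input_ids: List[int], max_position_embeddings: int = 4098, padding_idx: int = 1
) -> bool:
    """True if every position id is a valid row of the position embedding table."""
    positions = roberta_style_position_ids(input_ids, padding_idx)
    return max(positions, default=padding_idx) < max_position_embeddings


# 4096 non-padding tokens use positions 2..4097, which fit into 4098 rows;
# one more token would need row 4098 and trigger the IndexError.
print(fits_position_table([5] * 4096))  # True
print(fits_position_table([5] * 4097))  # False
```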

I tried to analyse the error using this script:

from disco_score import DiscoScorer

doc = """..."""
summ = """the daily and sunday politics are on-air six days a week for much of the year reporting the political news from westminster and beyond."""

n = len(doc)
# (label, truncation, part of the document used as reference) for each test
cases = [
    ("test 1", None, doc[:n // 8]),
    ("test 2", None, doc[n // 8:n // 4]),
    ("test 3", None, doc[:n // 4]),
    ("test 4", 3500, doc[:n // 4]),
    ("test 5", 4000, doc[:n // 4]),
]

for label, truncation, doc_part in cases:
    print(label)
    try:
        DiscoScorer(device="cpu",  # or device='cuda:0'
                    model_name='allenai/longformer-base-4096',
                    truncation=truncation,
                    ).DS_SENT_NN(summ, [doc_part])
        print("Passed")
    except Exception:  # unlike a bare except, this lets Ctrl+C through
        print("Error")

Results on my machine:

test 1
Passed
test 2
Passed
test 3
Error
test 4
Passed
test 5
Error

Since test 3 produces an error while tests 1 and 2 don't, the problem seems to be the length rather than the content. So perhaps Longformer raises an error for texts that are too long. Test 4 passing shows that the error can be avoided by truncating the text. However, I have to truncate to a smaller size than I would expect: I thought truncating to 4000 should be fine for longformer-base-4096, but test 5 fails. So if the input length is the cause of the bug, then the truncation code seems flawed, at least in combination with Longformer. Maybe there is a mismatch between tokenizers somewhere.
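If the fork's truncation counts a different unit than the Longformer tokenizer (for example words, or word pieces before the special tokens `<s>` and `</s>` are added), a limit of 4000 can still yield more than 4096 model tokens. One defensive option, sketched here under that assumption, is to measure candidate prefixes with the model's own tokenizer and search for the longest word prefix that fits, leaving room for special tokens. `longest_fitting_prefix` and the toy `count_subwords` are hypothetical; in practice `count_tokens` would wrap the scorer's actual tokenizer.

```python
from typing import Callable


def longest_fitting_prefix(
    text: str,
    count_tokens: Callable[[str], int],
    max_tokens: int,
    reserved: int = 2,
) -> str:
    """Return the longest word prefix of text whose token count fits.

    Hypothetical helper. `reserved` leaves room for special tokens such as
    <s> and </s>. The binary search assumes the token count never decreases
    as words are added, which is reasonable but not guaranteed.
    """
    words = text.split()
    budget = max_tokens - reserved
    lo, hi = 0, len(words)  # invariant: words[:lo] fits the budget
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if count_tokens(" ".join(words[:mid])) <= budget:
            lo = mid
        else:
            hi = mid - 1
    return " ".join(words[:lo])


def count_subwords(text: str) -> int:
    # Stand-in tokenizer: every 4 characters of a word count as one token.
    return sum((len(word) + 3) // 4 for word in text.split())


text = "the daily and sunday politics are on air six days a week"
# Cutting at 8 *words* gives 11 toy tokens, over an 8-token limit.
print(count_subwords(" ".join(text.split()[:8])))  # 11
print(longest_fitting_prefix(text, count_subwords, max_tokens=8))  # the daily and sunday
```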

What is your opinion on this strange behavior?