deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.14k stars, 658 forks

Support for overflow tokens in HuggingFaceTokenizer. #1996

Closed demq closed 2 years ago

demq commented 2 years ago

Description

The Python API of the HuggingFace tokenizer can tokenize inputs longer than maxLength (<= maxModelLength), for example when dealing with long contexts in QA, by returning encodings of overlapping input chunks of maxLength. The overlap between consecutive chunks is set with the stride parameter. Implementing this functionality in DJL would improve the tokenization of long inputs; a user-implemented workaround would have to recreate similar logic outside the HuggingFaceTokenizer class and tokenize the input multiple times.
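
As a rough illustration of the chunking described above, here is a minimal sketch in plain Java. `OverflowChunker` and `chunk` are hypothetical names, not part of DJL or the tokenizers library, and the sketch assumes that stride is the number of tokens shared by consecutive chunks, so each chunk advances by maxLength - stride tokens.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper, not part of DJL or the tokenizers library. */
final class OverflowChunker {

    private OverflowChunker() {}

    /**
     * Splits token ids into windows of at most maxLength tokens. Each window
     * after the first starts with the last `stride` tokens of the previous one
     * (assumption: stride means overlap).
     */
    static List<long[]> chunk(long[] ids, int maxLength, int stride) {
        if (maxLength <= 0) {
            throw new IllegalArgumentException("maxLength must be positive");
        }
        if (stride < 0 || stride >= maxLength) {
            throw new IllegalArgumentException("stride must satisfy 0 <= stride < maxLength");
        }
        List<long[]> chunks = new ArrayList<>();
        int step = maxLength - stride; // how far each window advances
        int start = 0;
        while (true) {
            int end = Math.min(start + maxLength, ids.length);
            chunks.add(Arrays.copyOfRange(ids, start, end));
            if (end == ids.length) {
                break; // the last window reached the end of the input
            }
            start += step;
        }
        return chunks;
    }
}
```

With ten token ids, maxLength = 4 and stride = 2, this sketch yields four windows starting at positions 0, 2, 4 and 6. It works on already-tokenized ids only and ignores special tokens and sentence pairs.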

Who will benefit from this enhancement? Users of the HuggingFaceTokenizer who need inputs longer than the maximum model size to be fully processed by NLP models, rather than relying on simple truncation and potentially discarding relevant parts of the input.

References

siddvenk commented 2 years ago

Thanks for the request @demq - I made a PR to add support for stride and overflow in our Tokenizer API.

siddvenk commented 2 years ago

The PR has been merged, so after our nightly build runs it should be available to use from 0.19.0-SNAPSHOT. Let me know if you run into any issues using this functionality - there seem to be some limitations/strange behavior with how stride and max length are modified when using pairs of sentences in the core tokenizer repo, so those issues will exist here as well.

demq commented 2 years ago

Hi @siddvenk ,

Thank you very much for implementing these options, it is really helpful. I have pulled the latest snapshot with your changes, but somehow I keep getting errors when running the encoding with these changes. I just took the code from your test:

HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
        .optTokenizerName("bert-base-cased")
        .optAddSpecialTokens(false)
        .optTruncation(true)
        .optMaxLength(8)
        .optStride(2)
        .build();
String text = "Hello there my friend I am happy to see you";
String textPair = "How are you my friend";
Encoding[] overflowing = tokenizer.encode(text, textPair).getOverflowing();

And I get the following exception:

Exception in thread "main" java.lang.UnsatisfiedLinkError: 'long[] ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.getOverflowing(long)'
    at ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.getOverflowing(Native Method)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.toEncoding(HuggingFaceTokenizer.java:422)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.encode(HuggingFaceTokenizer.java:211)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.encode(HuggingFaceTokenizer.java:222)

Somehow the error occurs when trying to fetch the overflowingHandles: `long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);`

I tried other models/settings (stride = 0, truncation and padding on or off) with the same result.

Any suggestions as to what I am doing wrong?

demq commented 2 years ago

Another question regarding HuggingFaceTokenizer: the encoding options such as maxLength/truncation/padding are not dynamically adjustable in DJL, while in the Python interface they can be set for each encoding call. I was wondering whether these settings could also be made adjustable in DJL, for example by creating getters/setters for these parameters?

In a scenario where we want to batch-process the inputs and need to pad them all to the same length, we often want to choose the smallest maxLength that fits the longest of the inputs.

Setting the maxLength=modelMaxLength would work, but this incurs a large inference time penalty if the inputs are significantly shorter than the modelMaxLength. This is a simple enough task that I can create a PR if the authors are happy with it.
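
To make the trade-off concrete, here is a minimal sketch in plain Java of padding a batch only up to its longest member, capped at a maximum length. `BatchPadding`, `padToLongest` and the pad id are hypothetical and not part of DJL; the sketch operates on token ids that have already been produced by a tokenizer.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of DJL. */
final class BatchPadding {

    private BatchPadding() {}

    /**
     * Pads every row to the length of the longest row in the batch, capped at
     * maxLength. Rows longer than the cap are truncated.
     */
    static long[][] padToLongest(long[][] batch, long padId, int maxLength) {
        int target = 0;
        for (long[] row : batch) {
            target = Math.max(target, Math.min(row.length, maxLength));
        }
        long[][] padded = new long[batch.length][];
        for (int i = 0; i < batch.length; i++) {
            // copyOf truncates long rows and zero-extends short ones
            long[] row = Arrays.copyOf(batch[i], target);
            if (batch[i].length < target) {
                // overwrite the zero-extended tail with the pad id
                Arrays.fill(row, batch[i].length, target, padId);
            }
            padded[i] = row;
        }
        return padded;
    }
}
```

Because the target length follows the batch rather than the model limit, a batch of short inputs stays short, which is the inference-time saving discussed above.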

frankfliu commented 2 years ago

@demq Please clean your cache folder; it is not automatically updated for snapshot versions.

rm -rf ~/.djl.ai/tokenizers

demq commented 2 years ago

Thanks @frankfliu - tried that too with no success. I can see the updated code for the overflowing tokens in the updated tokenizers.jar; the error occurs in the JNI call ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.getOverflowing(Native Method)

frankfliu commented 2 years ago

Are you building DJL from source?

cd extensions/tokenizers
rm -rf jnilib

demq commented 2 years ago

No, I am just pulling the latest snapshots through maven: ai.djl.huggingface:tokenizers:0.19.0-SNAPSHOT, the same for api and pytorch.engine.

frankfliu commented 2 years ago

You can check your cache:

cd ~/.djl.ai/tokenizers/0.12.0-0.19.0-SNAPSHOT-osx-x86_64
nm libtokenizers.dylib | grep getOverflowing

demq commented 2 years ago

Yep, nm shows no result for "getOverflowing". I did rm -rf ~/.djl.ai, ran the code, and got the following logs:

4:53:34.222 [main] INFO  ai.djl.huggingface.tokenizers.jni.LibUtils - Extracting native/lib/osx-aarch64/libtokenizers.dylib to cache ...
Exception in thread "main" java.lang.UnsatisfiedLinkError: 'long[] ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.getOverflowing(long)'
    at ai.djl.huggingface.tokenizers.jni.TokenizersLibrary.getOverflowing(Native Method)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.toEncoding(HuggingFaceTokenizer.java:422)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.encode(HuggingFaceTokenizer.java:211)
    at ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.encode(HuggingFaceTokenizer.java:222)

It looks like the library gets downloaded, but somehow it is missing the getOverflowing function:

nm libtokenizers.dylib | grep getTokenCharSpans
0000000000026f34 T _Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getTokenCharSpans
nm libtokenizers.dylib | grep getOverflowing

frankfliu commented 2 years ago

Oh, you are using a Mac M1. Please try it tomorrow. The M1 build is not automated since GitHub doesn't support M1.

demq commented 2 years ago

OK, the tribulations of the M1 :)

demq commented 2 years ago

Thanks for the help @frankfliu - it worked under Linux once I cleared ~/.djl.ai/tokenizers/

I would also appreciate it if you could respond to my comments regarding setting the tokenizer params - should I create a separate issue about it?

frankfliu commented 2 years ago

We use the Builder pattern and the tokenizer is immutable. It's not a good idea to change the settings in the middle of processing; that is error-prone in the multi-threading case.
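
The immutability argument can be pictured with a small stand-alone sketch. `EncodeSettings` is a hypothetical class written for illustration only, it is not DJL code, and its default of 512 is an arbitrary assumption. Instead of mutating a shared object, a caller that needs different settings derives a new instance, so other threads holding the original never see it change.

```java
/** Hypothetical settings holder illustrating an immutable builder; not DJL code. */
final class EncodeSettings {

    private final int maxLength;
    private final boolean padding;

    private EncodeSettings(Builder builder) {
        this.maxLength = builder.maxLength;
        this.padding = builder.padding;
    }

    int getMaxLength() {
        return maxLength;
    }

    boolean isPadding() {
        return padding;
    }

    static Builder builder() {
        return new Builder();
    }

    /** Returns a builder pre-filled with this instance's values. */
    Builder toBuilder() {
        return new Builder().optMaxLength(maxLength).optPadding(padding);
    }

    static final class Builder {
        private int maxLength = 512; // arbitrary default, an assumption
        private boolean padding;

        Builder optMaxLength(int maxLength) {
            this.maxLength = maxLength;
            return this;
        }

        Builder optPadding(boolean padding) {
            this.padding = padding;
            return this;
        }

        EncodeSettings build() {
            return new EncodeSettings(this);
        }
    }
}
```

Calling `base.toBuilder().optMaxLength(16).build()` produces a second object and leaves `base` untouched, which is the property that setters on a shared tokenizer would give up.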

For the batch-processing case, if you want to pad to the longest input, you can simply set padding to true:

HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
        .optTokenizerName("bert-base-cased")
        .optPadding(true) // pad to longest in the batch
        .build();

If you want to pad to a fixed length:

HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.builder()
        .optTokenizerName("bert-base-cased")
        .optPadToMaxLength() // pad to modelMaxLength or maxLength
        .optMaxLength(10)
        .build();

demq commented 2 years ago

Thanks a lot for the explanation, it makes a lot of sense.

Indeed, your solution with batchEncode(String[] inputs) would resolve my problem, but it looks like there is no batch encoding support for string pairs analogous to encode(String text, String textPair). Would it be possible to implement something like batchEncode(QAInput[] inputs)?

demq commented 2 years ago

I have just run some tests on Mac M1 - everything works fine, thanks a lot @siddvenk and @frankfliu.

The one edge case I've noticed was a Rust panic when maxLength is not large enough to accommodate the first text + stride (the same happens in Python):

HuggingFaceTokenizer tokenizer =
        HuggingFaceTokenizer.builder()
                .optTokenizerName("bert-base-cased")
                .optAddSpecialTokens(true)
                .optTruncation(true)
                .optTruncateSecondOnly()
                .optPadToMaxLength()
                .optMaxLength(9)
                .optStride(5)
                .build();
String text = "Who is Jane?";
String textPair = "Jane works as a singer at the opera.";
Encoding encoded = tokenizer.encode(text, textPair);
Encoding[] overflowing = encoded.getOverflowing();

I get:

thread '<unnamed>' panicked at 'assertion failed: stride < max_len', /Users/ec2-user/source/djl/extensions/tokenizers/tokenizers/tokenizers/src/tokenizer/encoding.rs:311:9
stack backtrace:
   0:        0x1324ef1c8 - <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt::h188b7ef1c7993e78
   1:        0x132509608 - core::fmt::write::he84a3004e7af3f34
   2:        0x1324ea088 - std::io::Write::write_fmt::h9370b50affaab0be
   3:        0x1324f09c0 - std::panicking::default_hook::{{closure}}::hc074f8023cce83ca
   4:        0x1324f0728 - std::panicking::default_hook::hef854b51b9b79ff2
   5:        0x1324f0e58 - std::panicking::rust_panic_with_hook::h1e59e224d558a492
   6:        0x1324f0d54 - std::panicking::begin_panic_handler::{{closure}}::he1a9d6ab32bfd8c6
   7:        0x1324ef6a4 - std::sys_common::backtrace::__rust_end_short_backtrace::he9b94791b02f48cd
   8:        0x1324f0ae4 - _rust_begin_unwind
   9:        0x1325289a0 - core::panicking::panic_fmt::h9fec86f6a9c4146e
  10:        0x1325288c0 - core::panicking::panic::h02e9fc642940f2ec
  11:        0x132251c78 - tokenizers::tokenizer::encoding::Encoding::truncate::h7de0730a0ffa68c6
  12:        0x1322410c8 - tokenizers::utils::truncation::truncate_encodings::h4cc49821851e4aa2
  13:        0x1321f0adc - tokenizers::tokenizer::TokenizerImpl<M,N,PT,PP,D>::post_process::hd54e9c17bab6afab
  14:        0x1321f1cb4 - tokenizers::tokenizer::TokenizerImpl<M,N,PT,PP,D>::encode::h2b0c6e912566f9be
  15:        0x1321c7920 - _Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_encodeDual
fatal runtime error: failed to initiate panic, error 5
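
The panic above comes from the `stride < max_len` assertion in the Rust truncation code. A caller could guard against it before encoding, along the lines of the sketch below. `StrideCheck` and its methods are hypothetical names and not part of DJL. The sketch assumes that, when only the second sequence is truncated, the length left for it is maxLength minus the token count of the first sequence minus the number of special tokens; the token counts in the usage note are illustrative guesses, not measured values.

```java
/** Hypothetical pre-check, not part of DJL or the tokenizers library. */
final class StrideCheck {

    private StrideCheck() {}

    /** Tokens left for the second sequence when only the second one is truncated (assumption). */
    static int budgetForSecond(int maxLength, int firstLength, int specialTokens) {
        return maxLength - firstLength - specialTokens;
    }

    /** True if the stride is strictly smaller than the remaining budget. */
    static boolean isStrideSafe(int maxLength, int firstLength, int specialTokens, int stride) {
        int budget = budgetForSecond(maxLength, firstLength, specialTokens);
        return budget > 0 && stride < budget;
    }
}
```

If the question were to tokenize to 4 tokens and the pair to add 3 special tokens, maxLength = 9 would leave a budget of 2, so stride = 5 fails the check; raising maxLength to 16 leaves a budget of 9 and the same stride passes.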