praeclarum / transformers-js

Browser-compatible JS library for running language models
MIT License

flan t5 support #5

Status: Open. loretoparisi opened this issue 1 year ago.

loretoparisi commented 1 year ago

Assuming the tokenizer is essentially the same, is Flan-T5 supported?

loretoparisi commented 1 year ago

To load flan-t5-small, I added a ReplaceTokenProcessor class to tokenizers.js:

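class ReplaceTokenProcessor extends TokenProcessor {
    constructor(pattern, content) {
        super();
        // tokenizer.json stores the pattern as either {"Regex": "..."} or {"String": "..."}.
        // Compile it into a global RegExp so that every match is replaced; passing the raw
        // string to replace() would match it literally and only once.
        const source = pattern.Regex !== undefined
            ? pattern.Regex
            : pattern.String.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
        this.pattern = new RegExp(source, "g");
        this.content = content;
    }
    normalize(text) {
        return text.replace(this.pattern, this.content);
    }
}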

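Replace is one of the normalizers declared in the /models/flan-t5-small-tokenizer.json tokenizer config. Processors are created in TokenProcessor.fromConfig, where I also added a guard on config.pretokenizers and config.normalizers in the 'Sequence' case (pre-tokenizer sequences list their children under pretokenizers, normalizer sequences under normalizers):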

class TokenProcessor {
    static fromConfig(config) {
        switch (config.type) {
            case "Metaspace":
                return new MetaspaceTokenProcessor(config.add_prefix_space, config.replacement, config.str_rep);
            case "Precompiled":
                return new PrecompiledTokenProcessor(config.precompiled_charsmap);
            case "Sequence":
                // Pre-tokenizer sequences list their children under "pretokenizers",
                // normalizer sequences under "normalizers"; both become a SequenceTokenProcessor.
                if (config.pretokenizers) {
                    return new SequenceTokenProcessor(config.pretokenizers.map(x => TokenProcessor.fromConfig(x)));
                }
                if (config.normalizers) {
                    return new SequenceTokenProcessor(config.normalizers.map(x => TokenProcessor.fromConfig(x)));
                }
                // Do not fall through into the WhitespaceSplit case.
                throw new Error("Sequence token processor config has no pretokenizers or normalizers");
            case "WhitespaceSplit":
                return new WhitespaceSplitTokenProcessor();
            case "Replace":
                return new ReplaceTokenProcessor(config.pattern, config.content);
            default:
                throw new Error("Unknown token processor type: " + config.type);
        }
    }
}
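For reference, here is a minimal standalone sketch of how a Sequence normalizer containing a Replace step could be built from a tokenizer.json-style config. It is not the transformers-js code: buildNormalizer is a hypothetical helper, Precompiled is stubbed out as a no-op, and the config below is an illustrative example rather than a copy of the flan-t5-small file. The detail it shows is that a Regex pattern has to be compiled into a global RegExp, because String.prototype.replace with a plain string pattern matches it literally and replaces only the first occurrence.

```javascript
// Hypothetical helper: builds a text -> text normalizer from a tokenizer.json-style config.
function buildNormalizer(config) {
  switch (config.type) {
    case "Sequence": {
      // Apply each child normalizer in order.
      const steps = config.normalizers.map(buildNormalizer);
      return (text) => steps.reduce((acc, step) => step(acc), text);
    }
    case "Replace": {
      // The pattern is either {"Regex": "..."} or {"String": "..."}; a String is escaped
      // so it is matched literally, but still replaced globally.
      const source = config.pattern.Regex !== undefined
        ? config.pattern.Regex
        : config.pattern.String.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
      const re = new RegExp(source, "g");
      return (text) => text.replace(re, config.content);
    }
    case "Precompiled":
      // Stub: the real Precompiled normalizer applies a charsmap; here it is a no-op.
      return (text) => text;
    default:
      throw new Error("Unknown normalizer type: " + config.type);
  }
}

// Illustrative config (an assumption, not copied from the flan-t5-small tokenizer file).
const exampleConfig = {
  type: "Sequence",
  normalizers: [
    { type: "Precompiled", precompiled_charsmap: "" },
    { type: "Replace", pattern: { Regex: " {2,}" }, content: " " },
  ],
};

const normalize = buildNormalizer(exampleConfig);
console.log(normalize("The  universe   is a dark forest."));
// prints "The universe is a dark forest."
```

With the raw string " {2,}" passed directly to replace(), the same input would come back unchanged, since that literal text never occurs in it.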

Loading the sessions in ONNX Runtime (ORT) seems to go well:

Loading session from /models/flan-t5-small-encoder-quantized.onnx
transformers.js:3 Loading session from /models/flan-t5-small-init-decoder-quantized.onnx
transformers.js:3 Loading session from /models/flan-t5-small-decoder-quantized.onnx
transformers.js:7 Session loaded from /models/flan-t5-small-encoder-quantized.onnx
transformers.js:7 Session loaded from /models/flan-t5-small-init-decoder-quantized.onnx
transformers.js:36 Loading model google/flan-t5-small... 50%
transformers.js:7 Session loaded from /models/flan-t5-small-decoder-quantized.onnx
transformers.js:36 Loading model google/flan-t5-small... 75%
transformers.js:36 Loading model google/flan-t5-small... 100%

However, the tokenization/decoding does not seem to work properly: while the input sequence is encoded correctly (the same sequence of IDs as for the t5-small model):

13959 4338 2 235 2 371 60 5457 10 2 634 2 7846 15 2 159 2 9 2 26 6604 2 1161 222 5 1

the output looks quite strange (for the input "The universe is a dark forest."):

0 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2
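A quick way to read ID sequences like the two above is to label T5's special token IDs. In the T5 vocabulary used by both models, 0 is <pad> (which T5 also uses as the decoder start token), 1 is </s> (end of sequence) and 2 is <unk>. describeIds below is a hypothetical diagnostic, not part of transformers-js; it leaves every other ID as a plain number, since mapping those back to text needs the actual vocabulary.

```javascript
// T5 special token IDs (<pad> doubles as the decoder start token).
const T5_SPECIAL = { 0: "<pad>", 1: "</s>", 2: "<unk>" };

// Hypothetical diagnostic: label special IDs, count <unk>, and check for a final </s>.
function describeIds(idString) {
  const ids = idString.trim().split(/\s+/).map(Number);
  const labels = ids.map((id) => T5_SPECIAL[id] ?? String(id));
  return {
    labels,
    unkCount: ids.filter((id) => id === 2).length,
    endsWithEos: ids[ids.length - 1] === 1,
  };
}

// The two sequences reported in this issue.
const input = "13959 4338 2 235 2 371 60 5457 10 2 634 2 7846 15 2 159 2 9 2 26 6604 2 1161 222 5 1";
const output = "0 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2 3 2";

console.log(describeIds(input));
console.log(describeIds(output));
```

For the output, the first label is <pad> (the decoder start token) and the sequence does not end with </s>; for the input, the helper also reports how many <unk> IDs the encoded prompt contains.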

That is surprising, since according to the Hugging Face docs Flan-T5 uses the same T5ForConditionalGeneration class for inference:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
# device_map="auto" needs the accelerate package; load_in_8bit=True needs bitsandbytes and a CUDA GPU
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto", load_in_8bit=True)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

# Flan-T5 is loaded and run with the same class as plain T5
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

So, in transformers-js, creating the model with

new T5ForConditionalGeneration(encoderSession, initDecoderSession, decoderSession);

should work as well, unless something under the hood in fromPretrained(modelId, modelsPath, progressAsyncCallback) or generate(inputTokenIds, options, progressAsyncCallback) behaves differently for the Flan-T5 flavour of the T5 architecture.
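To narrow down where generate might differ, it can help to recall what greedy decoding for an encoder-decoder model like T5 does: start the decoder with the decoder start token (ID 0 for T5), repeatedly pick the highest-scoring next token, and stop at </s> (ID 1) or a length limit. The sketch below is a generic illustration, not the transformers-js implementation; nextTokenLogits is a hypothetical stand-in for running the decoder session, mocked here with a fixed script.

```javascript
const DECODER_START_ID = 0; // T5 uses <pad> as the decoder start token
const EOS_ID = 1;           // </s>

// Index of the largest value in an array of logits.
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Generic greedy decoding loop. nextTokenLogits(decoderIds) is a hypothetical
// callback returning logits over the vocabulary for the next position.
function greedyDecode(nextTokenLogits, maxLength) {
  const ids = [DECODER_START_ID];
  while (ids.length < maxLength) {
    const next = argmax(nextTokenLogits(ids));
    ids.push(next);
    if (next === EOS_ID) break; // stop as soon as </s> is produced
  }
  return ids;
}

// Mock decoder over a 10-token vocabulary: emits 7, then 4, then </s>.
const script = [7, 4, EOS_ID];
function mockLogits(decoderIds) {
  const logits = new Array(10).fill(0);
  logits[script[decoderIds.length - 1]] = 1;
  return logits;
}

console.log(greedyDecode(mockLogits, 20)); // [ 0, 7, 4, 1 ]
```

If the real decoder never gives ID 1 the highest score, a loop like this only stops at maxLength, which would be consistent with the output reported above never containing ID 1.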