chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License

Inference API #109

Status: Open. chengchingwen opened this issue 2 years ago.

chengchingwen commented 2 years ago

As mentioned in #108, we currently don't have an inference API like the `pipeline` from Hugging Face Transformers. Right now you need to manually load the model and tokenizer, apply them to the input data, and convert the prediction results to the corresponding labels.
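The three manual steps (load, apply, map back to labels) can be pictured as one composable callable. This is a minimal, hypothetical sketch in plain Julia: `TextPipeline`, `toy_encode`, `toy_model`, and `toy_decode` are made-up names, not Transformers.jl API, and the toy functions only stand in for a real text encoder, model, and label lookup.

```julia
# Hypothetical pipeline: encode -> model -> decode.
# None of these names are Transformers.jl API; they are illustrative stand-ins.
struct TextPipeline{E,M,D}
    encode::E   # raw text -> model input
    model::M    # model input -> logits
    decode::D   # logits -> human-readable label
end

# Make the pipeline callable on a single string.
(p::TextPipeline)(text::AbstractString) = p.decode(p.model(p.encode(text)))

# Toy stand-ins: features = (number of '!', number of words).
toy_encode(text) = Float32[count(==('!'), text), length(split(text))]
# Two logits; the first grows with the number of '!'.
toy_model(x) = Float32[x[1] - 1, 1 - x[1]]
# Pick the label whose logit is largest.
toy_decode(logits) = ["excited", "calm"][argmax(logits)]

pipe = TextPipeline(toy_encode, toy_model, toy_decode)
```

In a real setup, `encode` would wrap the text encoder, `model` the loaded network, and `decode` a `onecold`-style lookup into stored label names.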

Broever101 commented 2 years ago

What's the current way to save and load a model? I'm saving it like so:

```julia
BSON.@save bsonname bert_model wordpiece tokenizer
```

and loading it with `load_bert_pretrain(bsonname)`, but it throws `ERROR: UndefVarError: Transformers not defined` while loading the tokenizer. Also, the Flux docs suggest calling `cpu(model)` before saving; do you think that breaks anything?

chengchingwen commented 2 years ago

Simply use `BSON.@save` and `BSON.@load`. The error is probably because you forgot to run `using Transformers` before loading. And yes, it's better to call `cpu(model)` before saving.

Broever101 commented 2 years ago


Weird. I have all the dependencies imported in the main module, and I'm including the loading script in that module. Anyway, importing them in the REPL solved the issue; probably a dumb mistake on my part.

Right now, I'm doing this:

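```julia
struct Pipeline
    bert_model
    wordpiece
    tokenizer
    bertenc
    function Pipeline(; ckpt::AbstractString="BERT_Twitter_Epochs_1")
        # Load the fine-tuned model and its tokenizer components from a BSON checkpoint.
        bert_model, wordpiece, tokenizer = load_bert_pretrain("ckpt/$ckpt.bson")
        # Move the model to the GPU if GPU support is enabled.
        bert_model = todevice(bert_model)

        bertenc = BertTextEncoder(tokenizer, wordpiece)
        # Inference mode: disables dropout.
        Flux.testmode!(bert_model)
        new(bert_model, wordpiece, tokenizer, bertenc)
    end
end

function (p::Pipeline)(query::AbstractString)
    # `preprocess` is defined elsewhere (not shown); "0" is presumably a placeholder label.
    data = todevice(preprocess([[query], ["0"]]))
    e = p.bert_model.embed(data.input)
    t = p.bert_model.transformers(e, data.mask)

    # Classify from the hidden state of the first ([CLS]) token of each sequence.
    prediction = p.bert_model.classifier.clf(
        p.bert_model.classifier.pooler(
            t[:,1,:]
        )
    )

    @info "Prediction: " prediction
end
```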
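The `t[:,1,:]` indexing in the snippet selects the first token's hidden state for every sequence in the batch. A small sketch with made-up numbers, assuming the feature-first `(hidden, seq_len, batch)` layout that this indexing implies:

```julia
# Made-up hidden states in feature-first layout: (hidden, seq_len, batch).
hidden, seq_len, batch = 4, 3, 2
t = reshape(Float32.(1:hidden*seq_len*batch), hidden, seq_len, batch)

# Keep every feature of the first token of every sequence.
cls = t[:, 1, :]   # size (hidden, batch)
```

For BERT the first token is `[CLS]`, so each column of `cls` is what the pooler and classifier consume; with 2 classes and a batch of 1 that yields a 2×1 logit matrix.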

Then I can run:

```julia
julia> p = Pipeline()
julia> p("this classifier sucks")
┌ Info: Prediction:
│   prediction =
│    2×1 Matrix{Float32}:
│     -0.06848035
└     -2.7152526
```

I have no idea how to interpret the results (should I take the absolute value to find which one-hot category is hot?), but is this the correct approach?

chengchingwen commented 2 years ago

Several points:

  1. You don't need `load_bert_pretrain`; you can just use `BSON.@load`.
  2. `BertTextEncoder` contains both the tokenizer and the wordpiece, so you don't need to store all of them.
  3. You need `Flux.onecold(prediction)` to turn the logits into a label index.
  4. However, the meaning of each label index is missing here, so you might want to store the label names in your checkpoint file as well.

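To illustrate points 3 and 4 without any packages: the predicted class is the row with the largest logit in each column, not the largest absolute value. The sketch below mimics what `Flux.onecold(prediction, labels)` returns for a logit matrix; `onecold_like`, `softmax_cols`, and the label names are hypothetical stand-ins, and the real label names would come from your checkpoint.

```julia
# Column-wise argmax mapped to labels, mimicking Flux.onecold on a logit matrix.
onecold_like(logits::AbstractMatrix, labels) =
    [labels[argmax(view(logits, :, j))] for j in axes(logits, 2)]

# Numerically stable softmax over each column (dims = 1).
function softmax_cols(x::AbstractMatrix)
    z = exp.(x .- maximum(x; dims = 1))
    return z ./ sum(z; dims = 1)
end

# The 2×1 prediction from the thread (2 classes, batch of 1).
prediction = reshape(Float32[-0.06848035, -2.7152526], 2, 1)
labels = ["class_1", "class_2"]   # hypothetical label names

predicted = onecold_like(prediction, labels)
probs = softmax_cols(prediction)
```

Softmax is monotonic, so it never changes which class wins; it only turns the logits into probabilities (here roughly 0.93 for the first class).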
stemann commented 1 year ago

Any further thoughts on an Inference API?

@ashwani-rathee and I have been discussing a framework-agnostic API, in particular for inference, that might be relevant to an inference API for Transformers.jl: https://julialang.zulipchat.com/#narrow/stream/390029-image-processing/topic/DL.20based.20tools/near/383544112