Hironsan / anago

Bidirectional LSTM-CRF and ELMo for Named-Entity Recognition, Part-of-Speech Tagging and so on.
https://anago.herokuapp.com/
MIT License
1.48k stars, 368 forks

Getting output probabilities for tokens #61

Closed: sureshiyengar closed this issue 6 years ago

sureshiyengar commented 6 years ago

Is there any way to compute output probabilities or scores for each token's predicted label?

olagus1 commented 6 years ago

I'm also interested in this. Would using a "softmax" activation in the last dense layer be an option? I tried it but couldn't get it to work.
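For context, a softmax turns each token's raw label scores z into a probability distribution, p_k = exp(z_k) / sum_j exp(z_j). The sketch below is independent of anago and only illustrates the math; it assumes the scores arrive as a NumPy array of shape (num_tokens, num_labels), and the name `logits` is hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the label axis (the last axis).

    Subtracting each row's maximum before exp() avoids overflow
    and does not change the resulting probabilities.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical scores for 3 tokens and 4 labels.
logits = np.array([[2.0, 0.5, -1.0, 0.0],
                   [0.1, 0.2, 3.0, 0.4],
                   [1000.0, 999.0, 0.0, 0.0]])  # large values stay finite
probs = softmax(logits)
print(probs.sum(axis=-1))  # each row should sum to 1 (up to rounding)
```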

Hironsan commented 6 years ago

If you set use_crf=False, you can get probabilities.

Example:

>>> from anago.tagger import Tagger
>>> from anago.trainer import Trainer
>>> from anago.preprocessing import IndexTransformer
>>> from anago.models import BiLSTMCRF
>>>
>>> # x_train, y_train: lists of tokenized sentences and their tag sequences
>>> p = IndexTransformer()
>>> p.fit(x_train, y_train)
>>>
>>> # use_crf=False is what makes predict_proba usable
>>> model = BiLSTMCRF(char_vocab_size=p.char_vocab_size,
...                   word_vocab_size=p.word_vocab_size,
...                   num_labels=p.label_size,
...                   use_crf=False)
>>> model.build()
>>> model.compile(loss=model.get_loss(), optimizer='adam')
>>> trainer = Trainer(model, preprocessor=p)
>>> trainer.train(x_train, y_train)
>>>
>>> sent = 'President Obama is speaking at the White House.'
>>> tagger = Tagger(model, preprocessor=p)
>>> tagger.predict_proba(sent)
[[1.32390960e-05 9.97208416e-01 7.74135287e-07 9.90919303e-04
  2.25378317e-04 5.41497429e-04 8.63012276e-04 3.13273013e-05
  1.74037984e-06 1.23720863e-04]
 [1.43429599e-04 1.03146426e-01 9.88335256e-03 5.01271844e-01
  3.57445449e-01 6.94832811e-03 8.67252331e-03 1.16340211e-02
  1.81677286e-04 6.73040922e-04]
 [1.39899294e-05 9.98564661e-01 2.11253905e-07 1.34401664e-04
  3.99311357e-05 3.67401633e-04 6.99305092e-04 1.43571096e-05
  1.80207076e-06 1.63966848e-04]
 [6.68837311e-06 9.99470174e-01 3.49865275e-07 1.23266567e-04
  7.50690233e-05 6.32712763e-05 1.76800706e-04 2.61328278e-05
  3.64134877e-07 5.79322223e-05]
 [4.57728629e-06 9.99763310e-01 6.55857306e-08 1.87382047e-05
  1.94344520e-05 2.07806825e-05 1.01221194e-04 1.60485743e-05
  1.54021791e-07 5.57465974e-05]
 [8.94824007e-06 9.99513030e-01 8.39019890e-07 5.34512656e-05
  1.38184769e-04 6.56975681e-06 6.30745053e-05 1.41227734e-04
  2.12021959e-07 7.45292418e-05]
 [5.65033348e-04 1.35039969e-03 1.81262687e-01 6.73109293e-02
  6.26478195e-01 1.72473097e-04 2.26516346e-03 1.19472988e-01
  4.99634363e-04 6.22511376e-04]
 [1.06257349e-02 3.56446224e-04 2.43019243e-03 1.28470513e-03
  1.00840600e-02 9.53651965e-02 4.64733094e-01 5.29475138e-03
  2.87371993e-01 1.22453794e-01]]
[[0.00995674 0.03049272 0.06292471 0.21905018 0.33271903 0.09431319
  0.145422   0.04879963 0.02932186 0.0269999 ]]
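In output like the above, each row appears to correspond to one token and each column to one label, so every row sums to roughly 1. A minimal sketch of turning such a matrix into (token, label, confidence) triples follows. It assumes a hypothetical `id2label` list whose order matches the column order (verify this against your preprocessor's label vocabulary); `decode_with_confidence` is an illustrative helper, not part of anago.

```python
import numpy as np


def decode_with_confidence(words, probs, id2label):
    """Pair each token with its most likely label and that label's probability.

    words:    list of tokens, one per row of `probs`.
    probs:    (num_tokens, num_labels) array of per-token label probabilities.
    id2label: hypothetical list mapping column index -> label string.
    """
    probs = np.asarray(probs)
    best = probs.argmax(axis=-1)  # highest-probability label index per token
    return [(word, id2label[i], float(probs[t, i]))
            for t, (word, i) in enumerate(zip(words, best))]


# Hypothetical 3-token, 3-label example.
words = ['Obama', 'is', 'speaking']
id2label = ['O', 'B-PER', 'I-PER']
probs = np.array([[0.05, 0.90, 0.05],
                  [0.97, 0.02, 0.01],
                  [0.99, 0.005, 0.005]])
for word, label, p in decode_with_confidence(words, probs, id2label):
    print(f'{word}\t{label}\t{p:.2f}')
```

Keeping the probability next to the label makes it easy to flag low-confidence tokens, for example by filtering on a threshold you choose.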

For now, probabilities are not available when use_crf=True.
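Some background: a CRF output layer is usually decoded with the Viterbi algorithm, which returns the single best label sequence rather than a distribution per token. For a linear-chain CRF, per-token marginal probabilities can in principle be computed with the forward-backward algorithm. The sketch below is not part of anago; it is a generic NumPy illustration that assumes you can obtain the emission scores (`unary`, shape (T, K)) and the transition matrix (`trans`, shape (K, K)) yourself, and it ignores start/end transition scores. All function names are hypothetical.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def crf_marginals(unary, trans):
    """Per-token label marginals P(y_t = k | x) for a linear-chain CRF.

    unary: (T, K) emission scores for T tokens and K labels (assumed input).
    trans: (K, K) transition scores; trans[i, j] scores label i -> label j.
    Start/end transitions are not modelled in this sketch.
    """
    unary = np.asarray(unary, dtype=float)
    trans = np.asarray(trans, dtype=float)
    T, K = unary.shape
    alpha = np.zeros((T, K))  # log-score of all prefixes ending in label k at t
    beta = np.zeros((T, K))   # log-score of all suffixes following label k at t
    alpha[0] = unary[0]
    for t in range(1, T):
        # [prev, cur] grid, reduced over the previous label
        alpha[t] = unary[t] + logsumexp(alpha[t - 1][:, None] + trans, axis=0)
    for t in range(T - 2, -1, -1):
        # [cur, next] grid, reduced over the next label
        beta[t] = logsumexp(trans + (unary[t + 1] + beta[t + 1])[None, :], axis=1)
    log_z = logsumexp(alpha[-1], axis=0)  # log partition function
    return np.exp(alpha + beta - log_z)


def brute_force_marginals(unary, trans):
    """Enumerate every label path; only feasible for tiny T and K."""
    T, K = unary.shape
    marg = np.zeros((T, K))
    for path in np.ndindex(*([K] * T)):
        score = sum(unary[t, path[t]] for t in range(T))
        score += sum(trans[path[t - 1], path[t]] for t in range(1, T))
        for t, k in enumerate(path):
            marg[t, k] += np.exp(score)
    return marg / marg[0].sum()  # row 0 sums over every path, i.e. Z


rng = np.random.default_rng(0)
unary = rng.normal(size=(4, 3))  # hypothetical scores: 4 tokens, 3 labels
trans = rng.normal(size=(3, 3))
marg = crf_marginals(unary, trans)
print(np.allclose(marg, brute_force_marginals(unary, trans)))
print(marg.sum(axis=1))  # each token's row should sum to 1
```

The brute-force check is only there to confirm the recursions on a toy input; the forward-backward version scales linearly in sentence length.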

With best regards