Gnomeek / ChineseSentimentAnalysiswithBERT

Chinese text sentiment classification based on BERT
Apache License 2.0
34 stars, 2 forks

PyTorch implementation? #1

Open lucasjinreal opened 4 years ago

lucasjinreal commented 4 years ago

Is there a PyTorch implementation?

Gnomeek commented 4 years ago

This repo is based on TensorFlow 1.13. You can find a PyTorch implementation in this repo. Alternatively, you can load a TensorFlow checkpoint and convert the TF model into a PyTorch model; check this repo for more information. bert-as-service also supports PyTorch on the client side. To quote:

Q: Do I need Tensorflow on the client side? A: No. Think of BertClient as a general feature extractor, whose output can be fed to any ML models, e.g. scikit-learn, pytorch, tensorflow. The only file that client need is client.py. Copy this file to your project and import it, then you are ready to go.

I think that it's easy to modify and reuse.

lucasjinreal commented 4 years ago

Thanks for your reply. I need a Chinese pretrained model, so I'd better stick with TensorFlow. I tried TF 1.14; it raised some errors, but I have fixed them, and I have frozen the model to a .pb file.

My question is: are there any snippets for loading this .pb file for inference only? I couldn't find any simple script that runs prediction just by loading the .pb file.

Gnomeek commented 4 years ago

Well, the .pb file is actually what bert-base needs; you can check this repo for more information. I forgot to upload the .sh script that uses the .pb file, but the basic idea is similar to run.sh. The format is as follows:

export BERT_BASE_DIR=yada yada
export TRAINED_CLASSIFIER=yada yada
export EXP_NAME=yada yada

bert-base-serving-start \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -bert_model_dir $BERT_BASE_DIR \
    -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1

Install bert-base with:

pip install bert-base==0.0.9 -i https://pypi.python.org/simple

This way, you can test the model's predictions simply with curl or through the API:

curl -X POST http://114.123.152.111:8091/encode \
  -H 'content-type: application/json' \
  -d '{"id": 1, "texts": ["今天天气不错"], "is_tokenized": false}'
# IP address and protocol may need to change based on your situation.
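For calling the service from Python instead of curl, a minimal standard-library sketch follows. It assumes the endpoint accepts the same JSON body as the curl example (id, texts, is_tokenized) and replies with JSON; the helper names build_encode_request and encode are hypothetical, and this thread does not document the response format.

```python
import json
import urllib.request


def build_encode_request(host, port, texts, request_id=1, is_tokenized=False):
    """Build a POST request for the /encode endpoint (same body as the curl example)."""
    payload = {"id": request_id, "texts": list(texts), "is_tokenized": is_tokenized}
    # ensure_ascii=False keeps Chinese text readable; the body is sent as UTF-8 bytes.
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    return urllib.request.Request(
        url="http://{}:{}/encode".format(host, port),
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def encode(host, port, texts, timeout=10.0):
    """Send the request and return the server's JSON reply as a Python object."""
    req = build_encode_request(host, port, texts)
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

For example, encode("114.123.152.111", 8091, ["今天天气不错"]) sends the same request as the curl command. Building the Request separately makes it easy to inspect the body and headers before anything goes over the network.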
Gnomeek commented 4 years ago

bertsvr.sh has been pushed; you can find it in the svr folder on the master branch. Change the parameters to match your model (e.g. max_seq_len).

To clarify, the difference between bert-as-service and bert-base is that bert-as-service can only load pretrained checkpoints, while bert-base can load both pretrained checkpoints and self-trained models (such as a .pb file). You can also change the parameters to run bert-base the same way as bert-as-service.

lucasjinreal commented 4 years ago

Thanks for this helpful suggestion, but I got an error when starting the service:

./bertsvr.sh
start BERT server...
2020-04-06 12:52:18.136044: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library libcudart.so.10.2
usage: /usr/local/bin/bert-base-serving-start -model_dir ../data/output -bert_model_dir ../chinese_L-12_H-768_A-12 -model_pb_dir ../data -mode CLASS -max_seq_len 128 -http_port 8091 -port 5575 -port_out 5576 -cpu -device_map 1
                 ARG   VALUE
__________________________________________________
      bert_model_dir = ../chinese_L-12_H-768_A-12
           ckpt_name = bert_model.ckpt
         config_name = bert_config.json
                cors = *
                 cpu = True
          device_map = [1]
                fp16 = False
 gpu_memory_fraction = 0.5
    http_max_connect = 10
           http_port = 8091
        mask_cls_sep = False
      max_batch_size = 1024
         max_seq_len = 128
                mode = CLASS
           model_dir = ../data/output
        model_pb_dir = ../data
          num_worker = 1
       pooling_layer = [-2]
    pooling_strategy = REDUCE_MEAN
                port = 5575
            port_out = 5576
       prefetch_size = 10
 priority_batch_size = 16
     tuned_model_dir = None
             verbose = False
                 xla = False

I:VENTILATOR:[__i:__i:104]:lodding classification predict, could take a while...
WARNING: Logging before flag parsing goes to stderr.
I0406 12:52:19.520388 140025864881984 __init__.py:104] lodding classification predict, could take a while...
I:VENTILATOR:[__i:__i:111]:contain 0 labels:dict_values(['0', '1'])
I0406 12:52:19.525076 140025864881984 __init__.py:111] contain 0 labels:dict_values(['0', '1'])
pb_file exits ../data/classification_model.pb
I:VENTILATOR:[__i:__i:114]:optimized graph is stored at: ../data/classification_model.pb
I0406 12:52:19.627817 140025864881984 __init__.py:114] optimized graph is stored at: ../data/classification_model.pb
I:VENTILATOR:[__i:_ru:148]:bind all sockets
I0406 12:52:19.628451 140024723961600 __init__.py:148] bind all sockets
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/__init__.py", line 134, in run
    self._run()
  File "/usr/local/lib/python3.6/dist-packages/zmq/decorators.py", line 75, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/zmq/decorators.py", line 75, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/zmq/decorators.py", line 75, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/zmq_decor.py", line 27, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/__init__.py", line 150, in _run
    addr_front2sink = auto_bind(sink)
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/helper.py", line 182, in auto_bind
    socket.bind('ipc://{}'.format(tmp_dir))
  File "zmq/backend/cython/socket.pyx", line 550, in zmq.backend.cython.socket.Socket.bind
  File "zmq/backend/cython/checkrc.pxd", line 26, in zmq.backend.cython.checkrc._check_rc
zmq.error.ZMQError: Input/output error

Do you have any idea?
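The traceback ends in auto_bind calling socket.bind('ipc://...'). On POSIX systems ZeroMQ's ipc:// transport is built on Unix domain sockets, so one generic check (not necessarily the fix that was applied in this thread) is whether a directory can hold a Unix socket at all; some mounted or network filesystems may not support them. The helper can_bind_unix_socket is hypothetical, and the log does not show which directory bert-base actually binds in.

```python
import os
import socket
import tempfile
import uuid


def can_bind_unix_socket(directory):
    """Return True if a Unix domain socket can be bound inside `directory`.

    ZeroMQ's ipc:// transport uses Unix domain sockets on POSIX systems, so a
    directory that fails this probe is unlikely to work for ipc:// either.
    """
    if not hasattr(socket, "AF_UNIX"):
        return False  # this Python build has no Unix domain socket support
    path = os.path.join(directory, "probe-%s.sock" % uuid.uuid4().hex[:8])
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    bound = False
    try:
        sock.bind(path)
        bound = True
    except OSError:
        pass  # missing directory, no permission, unsupported filesystem, path too long, ...
    finally:
        sock.close()
        if bound:
            os.remove(path)  # binding creates a socket file; clean it up
    return bound


if __name__ == "__main__":
    # Probe the system temp directory and the directory the server is started from.
    for candidate in (tempfile.gettempdir(), os.getcwd()):
        print(candidate, can_bind_unix_socket(candidate))
```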

lucasjinreal commented 4 years ago

Fixed it, but I got an error when testing:

I0406 13:02:47.061372 140011026507520 __init__.py:537] ready and listening!
I:PROXY:[htt:enc: 47]:new request from 127.0.0.1
I0406 13:04:28.258268 140012192503552 http.py:47] new request from 127.0.0.1
None
E:PROXY:[htt:enc: 54]:error when handling HTTP request
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/http.py", line 49, in encode_query
    return {'id': data['id'],
TypeError: 'NoneType' object is not subscriptable
E0406 13:04:28.273071 140012192503552 http.py:54] error when handling HTTP request
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/bert_base/server/http.py", line 49, in encode_query
    return {'id': data['id'],
TypeError: 'NoneType' object is not subscriptable
I0406 13:04:28.323652 140012192503552 _internal.py:113] 127.0.0.1 - - [06/Apr/2020 13:04:28] "POST /encode HTTP/1.1" 400 -

curl -X POST http://localhost:8091/encode \ -H 'content-type: application/json' \ -d '{"id": 1, "texts": ["今天天气不错"], "is_tokenized": false}'
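A likely cause, judging from the command as pasted here: each "\ " (backslash followed by a space) escapes the space instead of continuing the line, so the shell passes " -H" and " -d" as ordinary arguments. curl then treats them, together with the header and the JSON body, as extra URLs; the POST to /encode arrives without a JSON body, which would match the None and the TypeError in the log. The sketch below uses Python's shlex.split, which follows POSIX shell quoting rules but is not a full shell, to compare how the pasted form and a backslash-free one-liner are tokenized.

```python
import shlex

BODY = '{"id": 1, "texts": ["今天天气不错"], "is_tokenized": false}'

# The command as pasted: each "\ " escapes a space instead of ending a line.
pasted = (
    "curl -X POST http://localhost:8091/encode "
    "\\ -H 'content-type: application/json' "
    "\\ -d '" + BODY + "'"
)

# The same request on one line with the backslashes removed.
fixed = (
    "curl -X POST http://localhost:8091/encode "
    "-H 'content-type: application/json' "
    "-d '" + BODY + "'"
)

for name, cmd in (("pasted", pasted), ("fixed", fixed)):
    # In the pasted form, " -H" and " -d" come out as plain arguments with a leading space.
    print(name, shlex.split(cmd))
```

In a real shell, either keep the command on one line without backslashes, or end each continued line with a bare \ and nothing after it.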

lucasjinreal commented 4 years ago

Fixed it, but whatever the input is, the model always outputs label 0...


This is the eval result on my model:

eval_accuracy = 0.9783315
eval_f1 = 0.97806644
eval_loss = 0.060206193
eval_precision = 0.995364
eval_recall = 0.96135986
global_step = 11998
loss = 0.060203094
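A quick consistency check on these metrics: F1 is the harmonic mean of precision and recall, and plugging in the reported eval_precision (P) and eval_recall (R) reproduces eval_f1 to six decimal places, so the evaluation numbers agree with one another.

```latex
F_1 = \frac{2PR}{P + R}
    = \frac{2 \times 0.995364 \times 0.96135986}{0.995364 + 0.96135986}
    \approx \frac{1.9138060}{1.9567239}
    \approx 0.978066
```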
Gnomeek commented 4 years ago

It seems to work fine on my model. I think something may be wrong with your label2id.pkl, since the label is constantly 0.

Gnomeek commented 4 years ago

Or maybe the labels in your dataset don't follow the format "1 for positive and 0 for negative". Changing the get_label method of MyProcessor in run_classifier.py may help.

lucasjinreal commented 4 years ago

No, I am training on the default corpus, weibo_sentiment100k.

You can test my sentences; the model can't handle these cases.

Gnomeek commented 4 years ago

Please wait a few hours. I need to re-upload the data to the cloud server and rerun the model.

lucasjinreal commented 4 years ago

thanks!

Gnomeek commented 4 years ago

It still seems to work fine on my machine... Here are my guesses about where the problem might be:

  1. Something went wrong when the dataset was split. Delete the first row of every subset (train, dev, test); it looks like "label x_train" and so on.

  2. Something went wrong when label2id.pkl was generated. It should be approximately 28 bytes (for the weibo-senti-100k dataset).
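Both checks can be scripted. The sketch below assumes the split files are UTF-8 text files whose stray header row starts with "label", and that label2id.pkl is a pickled {label: id} dict; the data/ directory, the train/dev/test.tsv file names, and both helper names are hypothetical. For reference, a two-label dict such as {'0': 0, '1': 1} pickled with protocol 3 (the default in Python 3.6) is exactly 28 bytes, in line with the size mentioned above.

```python
import os
import pickle


def strip_header_row(path, header_prefix="label"):
    """Remove the first line of `path` if it looks like a column header.

    Returns True if a header row was removed. The default `header_prefix`
    is an assumption based on the stray row looking like "label x_train".
    """
    with open(path, encoding="utf-8") as f:
        lines = f.readlines()
    if lines and lines[0].startswith(header_prefix):
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(lines[1:])
        return True
    return False


def load_id2label(path):
    """Load label2id.pkl (a {label: id} dict) and return the inverse {id: label}."""
    with open(path, "rb") as f:
        label2id = pickle.load(f)
    id2label = {idx: label for label, idx in label2id.items()}
    if len(id2label) != len(label2id):
        # Two labels mapped to the same id would collapse predictions onto one label.
        raise ValueError("two labels share one id: %r" % (label2id,))
    return id2label


if __name__ == "__main__":
    data_dir = "data"  # hypothetical location of the split files
    for name in ("train", "dev", "test"):
        path = os.path.join(data_dir, name + ".tsv")  # hypothetical file names
        if os.path.exists(path):
            print(path, "header removed:", strip_header_row(path))
    pkl = os.path.join(data_dir, "label2id.pkl")
    if os.path.exists(pkl):
        print(pkl, os.path.getsize(pkl), "bytes ->", load_id2label(pkl))
```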

I hope these suggestions help. I will update this repo to the newest version later. Maybe you can start from scratch and retrain the model (after all, shutting down and restarting solves lots of weird problems, lmao).

Gnomeek commented 4 years ago

The newest version is up. I also included label2id.pkl and the split datasets in the data folder, so you may be able to reuse them directly.