Distilbert-punctuator is a Python package that provides a BERT-based punctuator (a fine-tuned model of the pretrained Hugging Face DistilBertForTokenClassification) with the following three components:
Model examples are available on the Hugging Face model hub:

English model: Qishuai/distilbert_punctuator_en (📎 Model details)
Simplified Chinese model: Qishuai/distilbert_punctuator_zh (📎 Model details)
pip install distilbert-punctuator for direct usage of the punctuator.
pip install distilbert-punctuator[data_process] for the data-processing component.
pip install distilbert-punctuator[training] for the training component.
make install for development.
Component for pre-processing the training data. To use this component, please install with pip install distilbert-punctuator[data_process].

The package provides a simple pipeline for generating NER-format training data.

examples/data_sample.py
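At its core, this pre-processing step turns punctuated text into aligned token and tag sequences. The following is a minimal sketch of that idea, not the package's actual API; the function name and the tag strings are illustrative (though they mirror the English tag set shown later in this document):

```python
from typing import List, Tuple

# Illustrative tag names; check the package's own label set before training.
PUNCT2TAG = {",": "COMMA", ".": "PERIOD", "?": "QUESTIONMARK", "!": "EXLAMATIONMARK"}

def text_to_ner_sample(text: str) -> Tuple[List[str], List[str]]:
    """Convert punctuated text into parallel token and tag sequences."""
    tokens, tags = [], []
    for word in text.split():
        # Strip one trailing punctuation mark, if present, and record it as the tag.
        if word[-1] in PUNCT2TAG:
            tokens.append(word[:-1].lower())
            tags.append(PUNCT2TAG[word[-1]])
        else:
            tokens.append(word.lower())
            tags.append("O")  # normal token: no punctuation follows it
    return tokens, tags

tokens, tags = text_to_ner_sample("Hello, how are you? Fine.")
# tokens: ['hello', 'how', 'are', 'you', 'fine']
# tags:   ['COMMA', 'O', 'O', 'QUESTIONMARK', 'PERIOD']
```

See examples/data_sample.py for the pipeline the package actually ships.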
Component for providing a training pipeline for fine-tuning a pretrained DistilBertForTokenClassification model from Hugging Face.

The latest version includes an implementation of R-Drop enhanced training:

R-Drop github repo
Paper of R-Drop

examples/english_train_sample.py
Arguments required for the training pipeline.
training_corpus(List[List[str]]): list of sequences for training; the longest sequence should be no longer than the pretrained LM's max input length
validation_corpus(List[List[str]]): list of sequences for validation; same length constraint as the training corpus
training_tags(List[List[int]]): tags (int) for training
validation_tags(List[List[int]]): tags (int) for validation
model_name_or_path(str): name or path of the pre-trained model
tokenizer_name(str): name of the pretrained tokenizer
epoch(int): number of epochs
batch_size(int): batch size
model_storage_dir(str): storage path for the fine-tuned model
label2id(Dict): mapping from tag labels to ids
early_stop_count(int): number of epochs after which to stop training early if the validation loss does not decrease; default 3
gpu_device(int): specific GPU card index; default is CUDA_VISIBLE_DEVICES from the environment
warm_up_steps(int): warm-up steps
r_drop(bool): whether to train with R-Drop
r_alpha(int): alpha value for the KL divergence term in the loss; default 0
plot_steps(int): number of steps between recordings of training status to TensorBoard
tensorboard_log_dir(Optional[str]): TensorBoard log output directory; default "runs"
addtional_model_config(Optional[Dict]): additional configuration for the model

You can also train your own NER models with the trainer provided in this repo. An example can be found in notebooks/R-drop NER.ipynb.
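Assembling these arguments might look like the sketch below. Only the argument dictionary and some pre-flight sanity checks are shown; the trainer itself is deliberately not invoked here, so check examples/english_train_sample.py for the package's actual entry point and signature:

```python
# Sketch of assembling the training arguments described above.
# The corpora here are toy data; the trainer class is not called.
label2id = {"O": 0, "COMMA": 1, "PERIOD": 2, "QUESTIONMARK": 3, "EXLAMATIONMARK": 4}

training_corpus = [["hello", "how", "are", "you"], ["fine", "thanks"]]
training_tags = [[1, 0, 0, 3], [2, 0]]  # ids from label2id, one per token

training_args = dict(
    training_corpus=training_corpus,
    validation_corpus=training_corpus,   # toy example: reuse for validation
    training_tags=training_tags,
    validation_tags=training_tags,
    model_name_or_path="distilbert-base-uncased",
    tokenizer_name="distilbert-base-uncased",
    epoch=5,
    batch_size=16,
    model_storage_dir="./models/punctuator",
    label2id=label2id,
    early_stop_count=3,
    r_drop=True,
    r_alpha=0,
    plot_steps=50,
    tensorboard_log_dir="runs",
)

# Cheap sanity checks before launching a long training run:
assert all(len(seq) <= 512 for seq in training_args["training_corpus"])
assert all(
    len(seq) == len(tags)
    for seq, tags in zip(training_args["training_corpus"], training_args["training_tags"])
)
```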
Validation of fine-tuned model
examples/train_sample.py
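The heart of this validation step is comparing predicted tag sequences against the ground-truth tags. A minimal, package-independent token-accuracy sketch (the function name is illustrative, not the package's API):

```python
from typing import List

def token_accuracy(pred_tags: List[List[int]], gt_tags: List[List[int]]) -> float:
    """Flatten predicted and ground-truth tag sequences and compute accuracy."""
    correct = total = 0
    for pred_seq, gt_seq in zip(pred_tags, gt_tags):
        assert len(pred_seq) == len(gt_seq), "sequences must be aligned"
        correct += sum(p == g for p, g in zip(pred_seq, gt_seq))
        total += len(gt_seq)
    return correct / total if total else 0.0

acc = token_accuracy([[1, 0, 0, 3]], [[1, 0, 2, 3]])  # 3 of 4 tokens match -> 0.75
```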
evaluation_corpus(List[List[str]]): list of sequences for evaluation; the longest sequence should be no longer than the pretrained LM's max_position_embeddings (512)
evaluation_tags(List[List[int]]): tags (int) for evaluation (the ground truth)
model_name_or_path(str): name or path of the fine-tuned model
tokenizer_name(str): name of the tokenizer
batch_size(int): batch size
label2id(Optional[Dict]): label2id mapping; the default is taken from the model config. Pass in this argument if your model doesn't have a label2id inside its config
gpu_device(int): specific GPU card index; default is CUDA_VISIBLE_DEVICES from the environment

Component for providing an inference interface for users of the punctuator.
+----------------------+          (child process)
|  user application    |        +-------------------+
|                      |<------>| punctuator server |
|  + inference object  |        +-------------------+
+----------------------+
The punctuator is deployed in a child process which communicates with the main process through a pipe connection. The user can therefore initialize an inference object and call its punctuation function when needed; the punctuator never blocks the main process except while actually punctuating. A graceful-shutdown mechanism is provided, so the user doesn't need to worry about shutting it down.
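The layout above can be sketched with Python's standard multiprocessing primitives. This illustrates the child-process/pipe pattern only, not the package's actual server code: the class and method names are assumptions, and the "model" is faked with a string transformation.

```python
import multiprocessing as mp

def punctuator_server(conn):
    """Child process: answer punctuation requests until told to shut down."""
    while True:
        request = conn.recv()
        if request is None:          # graceful-shutdown sentinel
            break
        # A real server would run model inference here; we fake it.
        conn.send(request.capitalize() + ".")
    conn.close()

class Inference:
    """Main-process handle; blocks only while a punctuation call is in flight."""
    def __init__(self):
        self._conn, child_conn = mp.Pipe()
        self._proc = mp.Process(target=punctuator_server, args=(child_conn,))
        self._proc.start()

    def punctuation(self, text: str) -> str:
        self._conn.send(text)
        return self._conn.recv()

    def terminate(self):
        self._conn.send(None)        # ask the server to exit cleanly
        self._proc.join()

if __name__ == "__main__":
    inf = Inference()
    print(inf.punctuation("hello world"))  # -> "Hello world."
    inf.terminate()
```

The design point this captures: the parent process only touches the pipe inside `punctuation`, so the child (where the heavy model would live) never blocks the main process between calls.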
examples/inference_sample.py
Arguments required for the inference pipeline.
model_name_or_path(str): name or path of the pre-trained model
tokenizer_name(str): name of the pretrained tokenizer
tag2punctuator(Dict[str, tuple]): tag-to-punctuator mapping.

dbpunctuator.utils provides two default mappings for English and Chinese:
NORMAL_TOKEN_TAG = "O"
DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP = {
NORMAL_TOKEN_TAG: ("", False),
"COMMA": (",", False),
"PERIOD": (".", True),
"QUESTIONMARK": ("?", True),
"EXLAMATIONMARK": ("!", True),
}
DEFAULT_CHINESE_TAG_PUNCTUATOR_MAP = {
NORMAL_TOKEN_TAG: ("", False),
"C_COMMA": (",", False),
"C_PERIOD": ("。", True),
"C_QUESTIONMARK": ("? ", True),
"C_EXLAMATIONMARK": ("! ", True),
"C_DUNHAO": ("、", False),
}
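Each tag appears to map to a tuple of the punctuation to append and a boolean that, presumably, capitalizes the following token. Under that assumption, here is a decoding sketch using the English map above; the package's own decoding logic may differ, and `apply_punctuation` is an illustrative name:

```python
NORMAL_TOKEN_TAG = "O"
DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP = {
    NORMAL_TOKEN_TAG: ("", False),
    "COMMA": (",", False),
    "PERIOD": (".", True),
    "QUESTIONMARK": ("?", True),
    "EXLAMATIONMARK": ("!", True),
}

def apply_punctuation(tokens, tags, tag_map=DEFAULT_ENGLISH_TAG_PUNCTUATOR_MAP):
    """Re-attach punctuation; the bool flag capitalizes the following token."""
    out, capitalize_next = [], True  # capitalize the very first token
    for token, tag in zip(tokens, tags):
        word = token.capitalize() if capitalize_next else token
        punct, capitalize_next = tag_map[tag]
        out.append(word + punct)
    return " ".join(out)

sentence = apply_punctuation(
    ["hello", "how", "are", "you"], ["COMMA", "O", "O", "QUESTIONMARK"]
)
# -> "Hello, how are you?"
```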
For your own fine-tuned model with different tags, pass in your own mapping.

tag2id_storage_path(Optional[str]): tag2id storage path; the default is taken from the model config. Pass in this argument if your model doesn't have a tag2id inside its config.