plkmo / NLP_Toolkit

Library of state-of-the-art models (PyTorch) for NLP tasks
Apache License 2.0

Pretrained punctuation model produces mangled output #4

Closed by adam-faulkner 4 years ago

adam-faulkner commented 4 years ago

Hi, when running the pretrained biLSTM for punctuation-restoration, I get the following output from inferer.infer_sentence("hi how are you"):

Predicted Label: .hI. How.

and inferer.infer_from_file("./data/input.txt", out_file="./data/output.txt") produces the following output.txt:

hi how are you,.hI. How.
i am fine thanks,"I, am fine thanks."

Does the pretrained model not work?

Tested on both OSX and Linux

plkmo commented 4 years ago

Which parameters did you use for punctuate.py? With the defaults, I have tested on my side and it's fine. I have re-uploaded the files to be sure.

v-iashin commented 4 years ago

Tested it today on my setup. I can confirm the same results as @adam-faulkner, using the following script:

from nlptoolkit.utils.config import Config
from nlptoolkit.punctuation_restoration.trainer import train_and_fit
from nlptoolkit.punctuation_restoration.infer import infer_from_trained

config = Config(task='punctuation_restoration') # loads default argument parameters as above
config.data_path = "./data/train.tags.en-fr.en" # sets training data path
config.batch_size = 32
config.lr = 5e-5 # change learning rate
config.model_no = 1 # sets model to PuncLSTM
inferer = infer_from_trained(config)
inferer.infer_from_file(in_file="./data/input.txt", out_file="./data/output.txt")

Here are the MD5 sums of the unpacked files from Google Drive:

c1ad3bbcc27df9408a6b4e1580755bf1  ./data/args.pkl
d1f9aac33e3c05ebb6dd1be599cc1071  ./data/eng.pkl
2f00522f3751170a47ae493621d6809b  ./data/idx_mappings.pkl
5ada91777cd0258a2e8060d90d1756af  ./data/input.txt
4720f06853946777fafe07b894f37339  ./data/mappings.pkl
1fc5795d76584ddd1dce55ac5e987010  ./data/test_accuracy_per_epoch_1.pkl
a7b67a5fb8f91993bc3725028af7ddd3  ./data/test_checkpoint_1.pth.tar
a39ab1b199017f4a66bfc19360e5c1ac  ./data/test_losses_per_epoch_1.pkl
989aa99a3eaaf1b66b25b69b087f442a  ./data/test_model_best_1.pth.tar
5d575762da3d9e628f0494b6ce8abeb9  ./data/train.tags.en-fr.en
87c92564facab511b1a2880cda9387a4  ./data/vocab.pkl
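For anyone who wants to compare their own download against the sums above, here is a minimal sketch using only the Python standard library. The helper names `md5sum` and `find_mismatches` are made up for this example and are not part of NLP_Toolkit; the two entries in `expected` are copied from the list above, and the remaining files can be added the same way.

```python
import hashlib
from pathlib import Path


def md5sum(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_mismatches(expected, root="."):
    """Return {name: actual_digest_or_None} for files that are missing or differ from `expected`."""
    mismatches = {}
    for name, want in expected.items():
        path = Path(root) / name
        got = md5sum(path) if path.is_file() else None  # None marks a missing file
        if got != want:
            mismatches[name] = got
    return mismatches


if __name__ == "__main__":
    # Two of the sums listed in this thread; extend with the other files as needed.
    expected = {
        "data/args.pkl": "c1ad3bbcc27df9408a6b4e1580755bf1",
        "data/test_model_best_1.pth.tar": "989aa99a3eaaf1b66b25b69b087f442a",
    }
    for name, got in find_mismatches(expected).items():
        print(f"{name}: {got or 'missing'}")
```

An empty result means every listed file is present and matches; any printed line points at a file worth re-downloading.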

conda environment

name: punc_restore
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - ca-certificates=2020.1.1=0
  - certifi=2020.4.5.1=py36_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.3=he6710b0_1
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - ncurses=6.2=he6710b0_1
  - openssl=1.1.1g=h7b6447c_0
  - pip=20.0.2=py36_3
  - python=3.6.10=h7579374_2
  - readline=8.0=h7b6447c_0
  - setuptools=47.1.1=py36_0
  - sqlite=3.31.1=h62c20be_1
  - tk=8.6.8=hbc83047_0
  - wheel=0.34.2=py36_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - audioread==2.1.8
    - beautifulsoup4==4.9.1
    - blis==0.2.4
    - boto==2.49.0
    - boto3==1.9.238
    - botocore==1.12.253
    - cffi==1.14.0
    - chardet==3.0.4
    - click==7.1.2
    - cycler==0.10.0
    - cymem==2.0.3
    - dask==2.1.0
    - decorator==4.4.2
    - docutils==0.15.2
    - en-core-web-lg==2.1.0
    - fairseq==0.8.0
    - fastbpe==0.1.0
    - fasttext==0.9.1
    - filelock==3.0.12
    - h5py==2.10.0
    - idna==2.9
    - jieba==0.39
    - jmespath==0.10.0
    - joblib==0.15.1
    - kenlm==0.0.0
    - keras==2.3.1
    - keras-applications==1.0.8
    - keras-preprocessing==1.1.2
    - kiwisolver==1.2.0
    - librosa==0.7.0
    - llvmlite==0.32.1
    - matplotlib==3.1.0
    - murmurhash==1.0.2
    - networkx==2.3
    - nltk==3.5
    - numba==0.49.1
    - numpy==1.16.4
    - pandas==0.24.2
    - pillow==7.1.2
    - plac==0.9.6
    - portalocker==1.7.0
    - preshed==2.0.1
    - pybind11==2.5.0
    - pycparser==2.20
    - pyparsing==2.4.7
    - python-dateutil==2.8.1
    - pytorch-nlp==0.4.1
    - pytz==2020.1
    - pyyaml==5.3.1
    - regex==2019.8.19
    - requests==2.23.0
    - resampy==0.2.2
    - s3transfer==0.2.1
    - sacrebleu==1.4.1
    - sacremoses==0.0.34
    - scikit-learn==0.21.2
    - scipy==1.4.1
    - sentencepiece==0.1.83
    - seqeval==0.0.12
    - six==1.12.0
    - soundfile==0.10.2
    - soupsieve==2.0.1
    - spacy==2.1.8
    - srsly==1.0.2
    - thinc==7.0.8
    - tokenizers==0.7.0
    - toolz==0.10.0
    - torch==1.4.0
    - torchtext==0.4.0
    - torchvision==0.5.0
    - tqdm==4.32.1
    - typing==3.7.4.1
    - urllib3==1.25.9
    - wasabi==0.6.0

plkmo commented 4 years ago

Thanks for the detailed info. I have identified the problem now: it's with Config(task='punctuation_restoration'). The max encoder and decoder lengths were 200, but they should be set to 80 for the pre-trained PuncLSTM to work well. I have pushed an update that sets them to 80 by default.
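For readers on an older checkout, the same effect can probably be had by overriding the two lengths before building the inferer. The sketch below is hedged: the attribute names `max_encoder_len` and `max_decoder_len` are assumptions (check the Config class in your copy for the real names), the helper `cap_sequence_lengths` is hypothetical, and a `SimpleNamespace` stands in for the real Config object so the snippet runs on its own.

```python
from types import SimpleNamespace

# Length the pretrained PuncLSTM expects, as stated in this thread.
PRETRAINED_MAX_LEN = 80


def cap_sequence_lengths(config, max_len=PRETRAINED_MAX_LEN):
    """Set both sequence-length limits on a config object and return it.

    NOTE: `max_encoder_len` / `max_decoder_len` are assumed attribute names,
    not confirmed against the NLP_Toolkit source.
    """
    config.max_encoder_len = max_len
    config.max_decoder_len = max_len
    return config


# Stand-in for Config(task='punctuation_restoration') with the old defaults of 200.
config = SimpleNamespace(max_encoder_len=200, max_decoder_len=200)
cap_sequence_lengths(config)
print(config.max_encoder_len, config.max_decoder_len)  # prints "80 80"
```

With the real library, the equivalent would be to set the two attributes on the Config instance before calling infer_from_trained(config), again assuming those attribute names exist.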

v-iashin commented 4 years ago

I tried it, and it works on the sample data as expected.

adam-faulkner commented 4 years ago

Works as expected now, thanks for the fix!