utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

1.7.0 working notebook/data fails in 1.9.1 #254

Open laurb opened 3 years ago

laurb commented 3 years ago

I have a notebook (and data) that trains a multilabel roberta model successfully with fast-bert 1.7.0.

I have a new machine and set up the conda env with the latest fast-bert (1.9.1), and now the notebook fails on the first validation attempt with 'ValueError: Input contains NaN, infinity or a value too large for dtype('float32')'.

I tried going back a release to fast-bert 1.8.0 (and its associated packages), but that produces the same error.

I can't find any documentation on what might have changed to cause this. Any ideas?
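(Note: the ValueError quoted above is raised by scikit-learn's input validation when a metric is handed NaN predictions, not by fast-bert itself. A minimal sketch that reproduces the same error, assuming a sklearn-backed metric such as roc_auc is among the metrics passed to the learner:)

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([[1, 0], [0, 1]])
# NaN scores, like the ones the broken model produces at validation time
y_pred = np.array([[np.nan, 0.2], [0.1, 0.9]], dtype=np.float32)

# Raises: ValueError: Input contains NaN, infinity or a value too large for dtype('float32')
roc_auc_score(y_true, y_pred)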

ddelange commented 3 years ago

Similar issue here (though I don't think it's an artifact of more recent fast-bert versions; I think it's upstream).

All predictions come back as np.nan (for every label) once I increase the training corpus from ~20k sentences to ~200k sentences, so everything works as expected with the latest versions and the small corpus.

When I enable FP16 (O1) training, all goes well on the 20k corpus, but on the big one I also get gradient overflow warnings, and the dynamic loss scale keeps converging towards zero (it was at 1e-230 when I killed the run).
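(Not part of the original report, but a quick way to confirm that the gradients themselves are overflowing, independent of apex's loss-scale messages, is a plain-PyTorch finiteness check run right after a backward pass, e.g. from a debugger; learner.model is assumed here to be the underlying transformer held by the fast-bert learner:)

import torch

def nonfinite_grad_params(model):
    # Names of parameters whose gradients contain NaN or inf after loss.backward()
    return [name for name, p in model.named_parameters()
            if p.grad is not None and not torch.isfinite(p.grad).all()]

# e.g. paused in a debugger right after the backward pass:
# print(nonfinite_grad_params(learner.model))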

Without FP16 it is silent, and indeed fails when trying to validate() on the all-NaN predictions after the first epoch (with logging_steps=0 and validate=True). If I then go up into validate() with ipdb at that error (see picture), self.predict_batch also returns NaNs for any sentences I pass to it.
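(The same check can be done outside the debugger; a minimal sketch, assuming learner is the trained multilabel learner and the (label, score) pair format that predict_batch returns:)

import numpy as np

preds = learner.predict_batch(["any example sentence", "another one"])
# one list of (label, score) pairs per input text
scores = np.array([[score for _, score in row] for row in preds])
print(np.isnan(scores).all())  # True in the failure mode described above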

I'm using sentence_transformers.SentenceTransformer("distilbert-multilingual-nli-stsb-quora-ranking").tokenizer, on these versions:

$ python -c "from packaging.markers import default_environment; print(default_environment())"
{'implementation_name': 'cpython', 'implementation_version': '3.8.5', 'os_name': 'posix', 'platform_machine': 'x86_64', 'platform_release': '4.15.0-52-generic', 'platform_system': 'Linux', 'platform_version': '#56-Ubuntu SMP Tue Jun 4 22:49:08 UTC 2019', 'python_full_version': '3.8.5', 'platform_python_implementation': 'CPython', 'python_version': '3.8', 'sys_platform': 'linux'}
$ pip list | grep -I "torch\|bert\|apex\|transformers"
apex                     0.1
fast-bert                1.9.1
pytorch-lamb             1.0.0
sentence-transformers    0.3.6
torch                    1.6.0
torchvision              0.7.0
transformers             3.1.0
$ /usr/local/cuda/bin/nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89

NaNs predicted after the first epoch on the big training corpus: https://github.com/kaushaltrivedi/fast-bert/blob/4496783ad385754e67db944f341cd3d5433ee921/fast_bert/learner_cls.py#L516-L518 [screenshot]

jackkwok commented 3 years ago

Any solution? I am running into all-NaN predictions as well. I trained on the Kaggle toxic comments training set for 6 epochs, using 1 GPU and no apex.

from fast_bert.learner_cls import BertLearner

# databunch, metrics, device, logger and OUTPUT_DIR are defined elsewhere in the script
learner = BertLearner.from_pretrained_model(
    databunch,
    pretrained_path='bert-base-uncased',
    metrics=metrics,
    device=device,
    logger=logger,
    output_dir=OUTPUT_DIR,
    finetuned_wgts_path=None,
    warmup_steps=500,
    multi_gpu=False,
    is_fp16=False, #True,
    multi_label=True,
    logging_steps=50)

NUM_EPOCHS = 6

learner.fit(epochs=NUM_EPOCHS,
    lr=6e-3,
    validate=True,  # Evaluate the model after each epoch
    schedule_type="warmup_cosine",
    optimizer_type="lamb")

Various package versions:

$ pip list | grep -I "torch\|bert\|apex\|transformers"
fast-bert                         1.9.9
pytorch-lamb                      1.0.0
torch                             1.9.0
torchvision                       0.10.0
transformers                      3.0.2

No matter what sentence I give it, the output looks like this: [[('toxic', nan), ('severe_toxic', nan), ('obscene', nan), ('threat', nan), ('insult', nan), ('identity_hate', nan)]]

jackkwok commented 3 years ago

Not sure if this will help solve the original issue, but mine was solved by lowering the learning rate to 1e-4.
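(For reference, a minimal sketch of that change against the fit call above; everything else in the setup is assumed unchanged:)

learner.fit(epochs=NUM_EPOCHS,
    lr=1e-4,  # down from 6e-3, which produced the all-NaN predictions above
    validate=True,
    schedule_type="warmup_cosine",
    optimizer_type="lamb")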