dmlc / gluon-nlp

NLP made easy
https://nlp.gluon.ai/
Apache License 2.0

How to run inference after fine-tuning a classifier in the BERT script? #662

Closed Gpwner closed 5 years ago

Gpwner commented 5 years ago

When I finished the fine-tuning job for CoLA, I wanted to run inference on test.tsv, but I got zero test samples:

...
...
    data_dev = task('dev').transform(trans, lazy=False)
    dataloader_dev = mx.gluon.data.DataLoader(
        data_dev,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)

    data_rest = task('test').transform(trans, lazy=False)
    dataloader_test = mx.gluon.data.DataLoader(
        data_rest,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)
    print('dataloader_train:{}'.format(len(dataloader_train)))
    print('dataloader_dev:{}'.format(len(dataloader_dev)))
    print('dataloader_test:{}'.format(len(dataloader_test)))

Here is my output:

dataloader_train:540
dataloader_dev:131
dataloader_test:0

Here is my command:

python finetune_classifier.py --task_name CoLA --epochs 4 --batch_size 16 --optimizer bertadam --gpu --lr 2e-5 --log_interval 500

Thanks in advance!

szha commented 5 years ago

I was able to reproduce the problem. Also, I noticed a warning:

gluon-nlp/scripts/bert/dataset.py:91: UserWarning: 1063 samples have been filtered out due to parsing error.
szha commented 5 years ago

Looks like the test.tsv file for CoLA only has two columns, so this index is a typo (which should have been 1): https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/dataset.py#L301
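
As a minimal sketch of that fix (my reading of the CoLA test.tsv layout: a header row plus two columns, index and sentence), the test branch of COLADataset.__init__ would look roughly like this; the full class appears later in the thread:

        elif segment == 'test':
            A_IDX = 1                  # the sentence is in column 1, not 3
            fields = [A_IDX]           # the test split carries no label column
            super(COLADataset, self).__init__(
                path, num_discard_samples=1, fields=fields)  # skip the header row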

szha commented 5 years ago

cc'd @haven-jeon @eric-haibin-lin

Gpwner commented 5 years ago

I was able to reproduce the problem. Also, I noticed a warning:

gluon-nlp/scripts/bert/dataset.py:91: UserWarning: 1063 samples have been filtered out due to parsing error.

@szha So how do I solve it?

Gpwner commented 5 years ago

@szha When I looked at test.tsv in the CoLA dataset, I found that the sentence index is 1, so I changed the code in class COLADataset:

  def __init__(self,
                 segment='train',
                 root=os.path.join(
                     os.getenv('GLUE_DIR', 'glue_data'), task_name)):
        self._supported_segments = ['train', 'dev', 'test']
        assert segment in self._supported_segments, 'Unsupported segment: %s' % segment
        path = os.path.join(root, '%s.tsv' % segment)
        if segment in ['train', 'dev']:
            A_IDX, LABEL_IDX = 3, 1
            fields = [A_IDX, LABEL_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=0, fields=fields)
        elif segment == 'test':
            A_IDX = 1
            fields = [A_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=1, fields=fields)

But it did not work.

szha commented 5 years ago

@Gpwner this is the fix that I was alluding to in my previous comment, and in my case it fixed the problem. Are you using the class that you edited, and it still doesn't work?

Gpwner commented 5 years ago

Are you using the class that you edited, and it still doesn't work?

@szha Yes, it still doesn't work.

Here is my preprocess_data():

def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len):
    """Data preparation function."""
    # transformation
    trans = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=task.get_labels(),
        pad=False,
        pair=task.is_pair,
        label_dtype='float32' if not task.get_labels() else 'int32')

    data_train = task('train').transform(trans, lazy=False)
    data_train_len = data_train.transform(
        lambda input_id, length, segment_id, label_id: length)

    num_samples_train = len(data_train)
    # bucket sampler
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0),
        nlp.data.batchify.Stack(
            'float32' if not task.get_labels() else 'int32'))
    batch_sampler = nlp.data.sampler.FixedBucketSampler(
        data_train_len,
        batch_size=batch_size,
        num_buckets=10,
        ratio=0,
        shuffle=True)
    # data loaders
    dataloader_train = gluon.data.DataLoader(
        dataset=data_train,
        num_workers=1,
        batch_sampler=batch_sampler,
        batchify_fn=batchify_fn)

    data_dev = task('dev').transform(trans, lazy=False)
    dataloader_dev = mx.gluon.data.DataLoader(
        data_dev,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)

    data_rest = task('test').transform(trans, lazy=False)
    dataloader_test = mx.gluon.data.DataLoader(
        data_rest,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)
    print('dataloader_train:{}'.format(len(dataloader_train)))
    print('dataloader_dev:{}'.format(len(dataloader_dev)))
    print('dataloader_test:{}'.format(len(dataloader_test)))
    return dataloader_train, dataloader_dev, num_samples_train
eric-haibin-lin commented 5 years ago

@Gpwner thanks for raising this issue. It looks like the test set does not have any labels, and the implementation of BERTDatasetTransform always returns labels (see https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/dataset.py#L544-L551). Therefore, you cannot reuse the trans object to read the test set. Could you try changing

        input_ids, valid_length, segment_ids = self._bert_xform(line[:-1])

        label = line[-1]
        if self.labels:  # for classification task
            label = self._label_map[label]
        label = np.array([label], dtype=self.label_dtype)

        return input_ids, valid_length, segment_ids, label

to

        input_ids, valid_length, segment_ids = self._bert_xform(line[:-1])
        if self.labels:
            label = line[-1]
            label = self._label_map[label]
            label = np.array([label], dtype=self.label_dtype)
            return input_ids, valid_length, segment_ids, label
        else:
            return input_ids, valid_length, segment_ids

and use a separate transform just for the test set:

    trans_test = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=None,
        pad=False,
        pair=task.is_pair)
Gpwner commented 5 years ago

@eric-haibin-lin I tried your solution, but it still doesn't work. Here is my code:

finetune_classifier.py:

...
def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, root_path):
    """Data preparation function."""
    # transformation
    trans = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=task.get_labels(),
        pad=False,
        pair=task.is_pair,
        label_dtype='float32' if not task.get_labels() else 'int32')

    data_train = task('train', root=root_path).transform(trans, lazy=False)
    data_train_len = data_train.transform(
        lambda input_id, length, segment_id, label_id: length)

    num_samples_train = len(data_train)
    # bucket sampler
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0),
        nlp.data.batchify.Stack(
            'float32' if not task.get_labels() else 'int32'))
    batch_sampler = nlp.data.sampler.FixedBucketSampler(
        data_train_len,
        batch_size=batch_size,
        num_buckets=10,
        ratio=0,
        shuffle=True)
    # data loaders
    dataloader_train = gluon.data.DataLoader(
        dataset=data_train,
        num_workers=1,
        batch_sampler=batch_sampler,
        batchify_fn=batchify_fn)

    data_dev = task('dev', root=root_path).transform(trans, lazy=False)
    dataloader_dev = mx.gluon.data.DataLoader(
        data_dev,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)

    trans_test = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=None,
        pad=False,
        pair=task.is_pair)

    data_test = task('test', root=root_path).transform(trans_test, lazy=False)
    dataloader_test = mx.gluon.data.DataLoader(
        data_test,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)
    print('dataloader_train:{}'.format(len(dataloader_train)))
    print('dataloader_dev:{}'.format(len(dataloader_dev)))
    print('dataloader_test:{}'.format(len(dataloader_test)))
    return dataloader_train, dataloader_dev, dataloader_test, num_samples_train

dataset.py:

@register(segment=['train', 'dev', 'test'])
class COLADataset(GLUEDataset):
    """Class for Warstdadt acceptability task

    Parameters
    ----------
    segment : str or list of str, default 'train'
        Dataset segment. Options are 'train', 'dev', 'test' or their combinations.
    root : str, default '$GLUE_DIR/CoLA
        Path to the folder which stores the CoLA dataset.
        The datset can be downloaded by the following script:
        https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
    """
    task_name = 'CoLA'
    is_pair = False

    def __init__(self,
                 segment='train',
                 root=os.path.join(
                     os.getenv('GLUE_DIR', 'glue_data'), task_name)):
        self._supported_segments = ['train', 'dev', 'test']
        assert segment in self._supported_segments, 'Unsupported segment: %s' % segment
        path = os.path.join(root, '%s.tsv' % segment)
        if segment in ['train', 'dev']:
            A_IDX, LABEL_IDX = 3, 1
            fields = [A_IDX, LABEL_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=0, fields=fields)
        elif segment == 'test':
            A_IDX = 1
            fields = [A_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=1, fields=fields)

    @staticmethod
    def get_metric():
        """Get metrics  Matthews Correlation Coefficient"""
        return MCC(average='micro')

    @staticmethod
    def get_labels():
        """Get classification label ids of the dataset."""
        return ['0', '1']

...

class BERTDatasetTransform(object):
    """Dataset Transformation for BERT-style Sentence Classification or Regression.

    Parameters
    ----------
    tokenizer : BERTTokenizer.
        Tokenizer for the sentences.
    max_seq_length : int.
        Maximum sequence length of the sentences.
    labels : list of int , float or None. defaults None
        List of all label ids for the classification task and regressing task.
        If labels is None, the default task is regression
    pad : bool, default True
        Whether to pad the sentences to maximum length.
    pair : bool, default True
        Whether to transform sentences or sentence pairs.
    label_dtype: int32 or float32, default float32
        label_dtype = int32 for classification task
        label_dtype = float32 for regression task
    """

    def __init__(self,
                 tokenizer,
                 max_seq_length,
                 labels=None,
                 pad=True,
                 pair=True,
                 label_dtype='float32'):
        self.label_dtype = label_dtype
        self.labels = labels
        if self.labels:
            self._label_map = {}
            for (i, label) in enumerate(labels):
                self._label_map[label] = i
        self._bert_xform = BERTSentenceTransform(
            tokenizer, max_seq_length, pad=pad, pair=pair)

    def __call__(self, line):
        '''
        ...
        '''
        input_ids, valid_length, segment_ids = self._bert_xform(line[:-1])
        if self.labels:
            label = line[-1]
            label = self._label_map[label]
            label = np.array([label], dtype=self.label_dtype)
            return input_ids, valid_length, segment_ids, label
        else:
            return input_ids, valid_length, segment_ids

Here is my output:

INFO:root:processing dataset...
dataloader_train:540
dataloader_dev:131
dataloader_test:0
Gpwner commented 5 years ago

@szha @eric-haibin-lin For now, my solution is to add fake labels to the test data and change the __init__ function of the CoLA class in dataset.py. That works.
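
A rough sketch of the fake-label idea (the exact details are my own guess, not Gpwner's code), assuming test.tsv has a header row and two tab-separated columns, index and sentence: write a copy that carries a dummy '0' label so the labeled code path in BERTDatasetTransform still finds a label field.

src = 'glue_data/CoLA/test.tsv'                      # assumed GLUE layout
dst = 'glue_data/CoLA/test_with_fake_labels.tsv'     # hypothetical output file

with open(src) as fin, open(dst, 'w') as fout:
    next(fin)                                        # drop the 'index<TAB>sentence' header
    for line in fin:
        if not line.strip():
            continue                                 # skip empty trailing lines
        index, sentence = line.rstrip('\n').split('\t', 1)
        fout.write('{}\t0\t{}\n'.format(index, sentence))  # '0' is the fake label

The __init__ change is then just picking the right field indices for this new file, as in the COLADataset snippet earlier in the thread.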

eric-haibin-lin commented 5 years ago

Hi @Gpwner, sorry about that. I found another bug in the Transform class: it drops the last entry when the label is not present, because of input_ids, valid_length, segment_ids = self._bert_xform(line[:-1]). The use of [:-1] is wrong if labels is None.
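
A tiny illustration of the slicing problem (values made up): for the unlabeled test split, each parsed line is a 1-tuple containing only the sentence, so line[:-1] drops the sentence itself.

line_train = ('some sentence', '1')   # train/dev: (sentence, label)
line_test = ('some sentence',)        # test: sentence only
print(line_train[:-1])                # ('some sentence',) -> correct
print(line_test[:-1])                 # ()                 -> sentence lost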

The following code should work end to end:

import os
import warnings
import numpy as np
from mxnet.metric import Accuracy, F1, MCC, PearsonCorrelation, CompositeEvalMetric
from mxnet.gluon.data import Dataset
from gluonnlp.data import TSVDataset, BERTSentenceTransform
from gluonnlp.data.registry import register
from mxnet import gluon
import gluonnlp as nlp
import mxnet as mx
from dataset import COLADataset

class BERTDatasetTransform(object):
    def __init__(self,
                 tokenizer,
                 max_seq_length,
                 labels=None,
                 pad=True,
                 pair=True,
                 label_dtype='float32'):
        self.label_dtype = label_dtype
        self.labels = labels
        if self.labels:
            self._label_map = {}
            for (i, label) in enumerate(labels):
                self._label_map[label] = i
        self._bert_xform = BERTSentenceTransform(
            tokenizer, max_seq_length, pad=pad, pair=pair)

    def __call__(self, line):
        if self.labels:  # for classification task
            input_ids, valid_length, segment_ids = self._bert_xform(line[:-1])
            label = line[-1]
            label = self._label_map[label]
            label = np.array([label], dtype=self.label_dtype)
            return input_ids, valid_length, segment_ids, label
        return self._bert_xform(line)

def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, root_path):
    """Data preparation function."""
    # transformation
    trans = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=task.get_labels(),
        pad=False,
        pair=task.is_pair,
        label_dtype='float32' if not task.get_labels() else 'int32')

    data_train = task('train', root=root_path).transform(trans, lazy=False)
    data_train_len = data_train.transform(
        lambda input_id, length, segment_id, label_id: length)

    num_samples_train = len(data_train)
    # bucket sampler
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Pad(axis=0), nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0),
        nlp.data.batchify.Stack(
            'float32' if not task.get_labels() else 'int32'))
    batch_sampler = nlp.data.sampler.FixedBucketSampler(
        data_train_len,
        batch_size=batch_size,
        num_buckets=10,
        ratio=0,
        shuffle=True)

    # data loaders
    dataloader_train = gluon.data.DataLoader(
        dataset=data_train,
        num_workers=1,
        batch_sampler=batch_sampler,
        batchify_fn=batchify_fn)

    data_dev = task('dev', root=root_path).transform(trans, lazy=False)
    dataloader_dev = mx.gluon.data.DataLoader(
        data_dev,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)

    trans_test = BERTDatasetTransform(
        tokenizer,
        max_len,
        labels=None,
        pad=False,
        pair=task.is_pair)

    data_test = task('test', root=root_path)
    data_test = data_test.transform(trans_test, lazy=False)
    dataloader_test = mx.gluon.data.DataLoader(
        data_test,
        batch_size=dev_batch_size,
        num_workers=1,
        shuffle=False,
        batchify_fn=batchify_fn)
    print('dataloader_train:{}'.format(len(dataloader_train)))
    print('dataloader_dev:{}'.format(len(dataloader_dev)))
    print('dataloader_test:{}'.format(len(dataloader_test)))
    return dataloader_train, dataloader_dev, dataloader_test, num_samples_train

bert, vocabulary = nlp.model.get_bert_model(
    model_name='bert_12_768_12',
    dataset_name='book_corpus_wiki_en_uncased',
    pretrained=False,
    ctx=mx.cpu(),
    use_pooler=True,
    use_decoder=False,
    use_classifier=False)

# data processing
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)
preprocess_data(bert_tokenizer, COLADataset, 1, 1, 128, 'glue_data/CoLA')

You also need to update the field indices for COLADataset as mentioned before:

@register(segment=['train', 'dev', 'test'])
class COLADataset(GLUEDataset):
    """Class for Warstdadt acceptability task

    Parameters
    ----------
    segment : str or list of str, default 'train'
        Dataset segment. Options are 'train', 'dev', 'test' or their combinations.
    root : str, default '$GLUE_DIR/CoLA
        Path to the folder which stores the CoLA dataset.
        The datset can be downloaded by the following script:
        https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
    """
    task_name = 'CoLA'
    is_pair = False

    def __init__(self,
                 segment='train',
                 root=os.path.join(
                     os.getenv('GLUE_DIR', 'glue_data'), task_name)):
        self._supported_segments = ['train', 'dev', 'test']
        assert segment in self._supported_segments, 'Unsupported segment: %s' % segment
        path = os.path.join(root, '%s.tsv' % segment)
        if segment in ['train', 'dev']:
            A_IDX, LABEL_IDX = 3, 1
            fields = [A_IDX, LABEL_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=0, fields=fields)
        elif segment == 'test':
            A_IDX = 1
            fields = [A_IDX]
            super(COLADataset, self).__init__(
                path, num_discard_samples=1, fields=fields)

This gives the following result with batch_size = 1:

dataloader_train:8551
dataloader_dev:1043
dataloader_test:1063
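
For reference, here is a hedged sketch (my own assumption, not code from finetune_classifier.py) of how dataloader_test could then be consumed for prediction once the classifier is fine-tuned. It assumes a three-field batchify_fn matching trans_test's unlabeled output (the four-field batchify_fn above would need adjusting), and that model is the BERTClassifier used by the script, whose forward takes token ids, segment ids, and valid length.

import mxnet as mx

def predict(model, dataloader_test, ctx=mx.cpu()):
    """Collect argmax class ids for every unlabeled test batch."""
    predictions = []
    for input_ids, valid_length, segment_ids in dataloader_test:
        out = model(input_ids.as_in_context(ctx),
                    segment_ids.as_in_context(ctx),
                    valid_length.astype('float32').as_in_context(ctx))
        predictions.extend(out.argmax(axis=-1).asnumpy().astype('int').tolist())
    return predictions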

I'll make a PR for this fix shortly. Thanks a lot for reporting this issue!

Gpwner commented 5 years ago

@eric-haibin-lin I see your default BERT model is bert_12_768_12. I wonder whether the accuracy of bert_24_1024_16 is better than bert_12_768_12? I have tried the two different models on my own dataset, but the accuracies I got are about the same.

eric-haibin-lin commented 5 years ago

Good question. I have not specifically tried BERT large on CoLA. Did you try multiple seeds? What accuracy do you get currently? On MRPC the performance of BERT large has large variance (which is also reported in the paper), and multiple random seeds are needed to get a good result on the dev set.

Gpwner commented 5 years ago

@eric-haibin-lin Thanks for the reply. I have tried BERT large on my dataset (it is not CoLA, but it is similar), and the accuracy with BERT base and BERT large is about the same.

eric-haibin-lin commented 5 years ago

@Gpwner I'll include the fix in PR #682 .

BTW, since you have worked on a CoLA-like dataset, would you be interested in contributing the CoLA fine-tuning command/logs to gluon-nlp, just like RTE and SST in http://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-for-sentence-classification-on-glue-tasks ?

Gpwner commented 5 years ago

@eric-haibin-lin it seems that I don't have access rights to edit http://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-for-sentence-classification-on-glue-tasks .

Gpwner commented 5 years ago

@eric-haibin-lin I just don't know why I get a NaN loss once I reach epoch 2:

INFO:root:processing dataset...
INFO:root:Now we are doing BERT classification training on gpu(0)!
INFO:root:[Epoch 1 Batch 100/274] loss=0.6561, lr=0.0000486, metrics:mcc:0.0000
INFO:root:[Epoch 1 Batch 200/274] loss=0.6064, lr=0.0000417, metrics:mcc:0.0000
INFO:root:validation metrics:mcc:0.0000
INFO:root:params saved in : ./output_dir/model_bert_CoLA_0.params
INFO:root:Time cost=34.00s
WARNING:py.warnings:/home/xzc/project/gluon-nlp/scripts/bert/finetune_classifier.py:435: UserWarning: nan or inf is detected. Clipping results will be undefined.
  nlp.utils.clip_grad_global_norm(params, 1)

INFO:root:[Epoch 2 Batch 100/274] loss=nan, lr=0.0000296, metrics:mcc:0.0118
INFO:root:[Epoch 2 Batch 200/274] loss=nan, lr=0.0000227, metrics:mcc:0.0068
INFO:root:validation metrics:mcc:0.0000
INFO:root:params saved in : ./output_dir/model_bert_CoLA_1.params
INFO:root:Time cost=33.75s
INFO:root:[Epoch 3 Batch 100/274] loss=nan, lr=0.0000106, metrics:mcc:0.0000
INFO:root:[Epoch 3 Batch 200/274] loss=nan, lr=0.0000037, metrics:mcc:0.0000
INFO:root:validation metrics:mcc:0.0000
INFO:root:params saved in : ./output_dir/model_bert_CoLA_2.params
INFO:root:Time cost=33.32s
eric-haibin-lin commented 5 years ago

Thanks! @Gpwner you can edit scripts/bert/index.rst, which contains the content for the website. The NaN loss looks strange. Did you try a smaller learning rate? Does the NaN loss always happen at epoch 2?

Gpwner commented 5 years ago

@eric-haibin-lin I tried several learning rates, including 2e-5, 3e-5, 5e-5, and 10e-5, but I always get the NaN loss. I remember the old code didn't produce a NaN loss, but since there is no big difference between the BERT large and BERT base models, I had already deleted the old code from my disk and am using the official code instead.

eric-haibin-lin commented 5 years ago

@Gpwner we found that there's a recent regression from mxnet nightly build causing NaNs. https://github.com/dmlc/gluon-nlp/issues/690 What version of mxnet are you using? If you use mxnet-cu90==1.5.0b20190407, do you still see NaNs?

Gpwner commented 5 years ago

@eric-haibin-lin I am using mxnet-cu100==1.5.0b20190427

szha commented 5 years ago

@Gpwner that version likely has the regression. The change in mxnet that caused the regression was recently reverted. Could you try pip install mxnet-cu100==1.5.0b20190407, as Haibin suggests, and see if that solves the problem?

szha commented 5 years ago

Relevant issues https://github.com/dmlc/gluon-nlp/issues/690 https://github.com/apache/incubator-mxnet/issues/14864

szha commented 5 years ago

@Gpwner the related issues have been resolved. Let us know in case you still have some trouble with it.