ThilinaRajapakse / pytorch-transformers-classification

Based on the Pytorch-Transformers library by HuggingFace. To be used as a starting point for employing Transformer models in text classification tasks. Contains code to easily train BERT, XLNet, RoBERTa, and XLM models for text classification.
Apache License 2.0
304 stars 97 forks

Can your code be used for multi-class (> 2) classification? #7

Closed yiwang-verisk closed 4 years ago

yiwang-verisk commented 5 years ago

Can your code be used in situations where the number of labels is > 2?

ThilinaRajapakse commented 5 years ago

Of course! Define your own processor class that inherits from DataProcessor. Make sure your processor can read the data properly and that it returns the correct labels. I think that's all you need to change. Please let me know how it goes.

Edit: Turns out this isn't that straightforward.

yiwang-verisk commented 5 years ago

Do I need to change labels like "A", "B", "C" into "0", "1", "2", or can I keep the original string labels for the different categories?

ThilinaRajapakse commented 5 years ago

Any label is fine as long as it's a string. Just make sure that you include all the labels that are present in your data. For example, if you have the labels "A", "B", and "C", your DataProcessor subclass should look something like this.

class NewDataProcessor(DataProcessor):
    """Processor for the multiclass data sets"""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["A", "B", "C"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]  # text column
            label = line[1]   # label column
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

This assumes that your dataset is in a tsv file with the columns id, label, alpha, and text, in that order. The labels you use are mapped to integer indices before the examples are converted to features (the line of code below, from the convert_examples_to_features() function), so it shouldn't matter which strings you use as labels, as long as every label that appears in the data is listed in get_labels().

label_map = {label : i for i, label in enumerate(label_list)}
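
With the three labels above, that comprehension produces a plain string-to-index lookup (a quick illustration, not code from the repo):

label_list = ["A", "B", "C"]
label_map = {label: i for i, label in enumerate(label_list)}
print(label_map)  # {'A': 0, 'B': 1, 'C': 2}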

yiwang-verisk commented 5 years ago

Thanks so much! I will try your code.

ThilinaRajapakse commented 5 years ago

No problem! Let me know if anything goes wrong.

amoelle commented 5 years ago

Hi there,

I tried to run a multi-class classification using your code, but I ran into an error during the training step:

INFO:__main__:Creating features from dataset file at data/
100%|██████████| 45/45 [00:00<00:00, 197.45it/s]
INFO:__main__:Saving features into cached file data/cached_train_bert-base-german-cased_128_multi
INFO:__main__: Running training
INFO:__main__:  Num examples = 45
INFO:__main__:  Num Epochs = 1
INFO:__main__:  Total train batch size = 8
INFO:__main__:  Gradient Accumulation steps = 1
INFO:__main__:  Total optimization steps = 6
Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Selected optimization level O1: Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
HBox(children=(IntProgress(value=0, description='Iteration', max=6, style=ProgressStyle(description_width='ini…

RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
      1 if args['do_train']:
      2     train_dataset = load_and_cache_examples(task, tokenizer)
----> 3     global_step, tr_loss = train(train_dataset, model, tokenizer)
      4     logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

<ipython-input> in train(train_dataset, model, tokenizer)
     45         outputs = model(**inputs)
     46         loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
---> 47         print(loss)
     48         print("\r%f" % loss, end='')
     49

~/miniconda3/envs/transformers/lib/python3.7/site-packages/torch/tensor.py in __repr__(self)
     80         # characters to replace unicode characters with.
     81         if sys.version_info > (3,):
---> 82             return torch._tensor_str._str(self)
     83         else:
     84             if hasattr(sys.stdout, 'encoding'):

~/miniconda3/envs/transformers/lib/python3.7/site-packages/torch/_tensor_str.py in _str(self)
    298         tensor_str = _tensor_str(self.to_dense(), indent)
    299     else:
--> 300         tensor_str = _tensor_str(self, indent)
    301
    302     if self.layout != torch.strided:

~/miniconda3/envs/transformers/lib/python3.7/site-packages/torch/_tensor_str.py in _tensor_str(self, indent)
    199     if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
    200         self = self.float()
--> 201     formatter = _Formatter(get_summarized_data(self) if summarize else self)
    202     return _tensor_str_with_formatter(self, indent, formatter, summarize)
    203

~/miniconda3/envs/transformers/lib/python3.7/site-packages/torch/_tensor_str.py in __init__(self, tensor)
     85
     86         else:
---> 87             nonzero_finite_vals = torch.masked_select(tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0))
     88
     89     if nonzero_finite_vals.numel() == 0:

RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1565272271120/work/aten/src/THC/THCReduceAll.cuh:327

Do you have any idea what could be wrong?

Kind regards!
ThilinaRajapakse commented 5 years ago

This is not as straightforward as I thought (I really should have realized this, sorry). The issue is that the pre-trained models are designed for binary classification. It doesn't seem trivial to adapt these models for multiclass classification. Changing the config files and such breaks the loading of weights.

I think one way to do it would be to use the base model class (e.g., RobertaModel instead of RobertaForSequenceClassification) and add your own classification head.
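
For context, a CUDA device-side assert like the one above is typically what CrossEntropyLoss raises when a label index falls outside the range the classification head was built for (here, anything other than 0 or 1). Below is a minimal sketch of the custom-head idea, assuming the pytorch-transformers API of that era (RobertaModel returns the sequence of hidden states as outputs[0]); the class name and defaults are made up for illustration, not code from this repo.

import torch.nn as nn
from pytorch_transformers import RobertaModel

class RobertaMultiClassHead(nn.Module):
    """Hypothetical wrapper: base RobertaModel plus a num_labels-way classifier."""

    def __init__(self, num_labels, model_name="roberta-base", dropout_prob=0.1):
        super().__init__()
        self.num_labels = num_labels
        self.roberta = RobertaModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.roberta(input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first (<s>) token as the sentence representation
        cls_state = outputs[0][:, 0, :]
        logits = self.classifier(self.dropout(cls_state))
        if labels is not None:
            # CrossEntropyLoss works for any number of classes
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        return logits

The labels fed to a head like this just need to be integer indices in [0, num_labels), which is exactly what the label_map shown earlier produces.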

ThilinaRajapakse commented 4 years ago

Multiclass classification is now supported in the Simple Transformers library.
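
For anyone landing here later: in Simple Transformers the multiclass case is handled by passing num_labels. A rough usage sketch (the exact API may have changed since; the data below is made up, and labels must already be integer-encoded):

from simpletransformers.classification import ClassificationModel
import pandas as pd

# Two columns: the text and an integer class label per example (made-up data)
train_df = pd.DataFrame(
    [["first example text", 0], ["second example text", 2], ["third example text", 1]],
    columns=["text", "labels"],
)

# model_type, model_name, and the number of classes
model = ClassificationModel("roberta", "roberta-base", num_labels=3)
model.train_model(train_df)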