allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Textual Entailment using RoBERTa only predicting one category #4358

Closed ud2195 closed 4 years ago

ud2195 commented 4 years ago

Hi, I followed the config file given here: https://github.com/allenai/allennlp-models/blob/master/training_config/pair_classification/snli_roberta.jsonnet. The only change I made was to max_length, so my config now looks like this:

local transformer_model = "roberta-large";
local transformer_dim = 1024;
local cls_is_last_token = false;

{
  "dataset_reader":{
    "type": "snli",
    "lazy": true,
    "tokenizer": {
      "type": "pretrained_transformer",
      "model_name": transformer_model,
      "add_special_tokens": false
    },
    "token_indexers": {
      "tokens": {
        "type": "pretrained_transformer",
        "model_name": transformer_model,
        "max_length": 40
      }
    }
  },
  "train_data_path": "/opt/ml/input/data/training/cnli_sampletrain_5L.jsonl",
  "validation_data_path": "/opt/ml/input/data/validation/cnli_sampleval_5L.jsonl",

  "model": {
    "type": "basic_classifier",
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer",
          "model_name": transformer_model,
          "max_length": 40
        }
      }
    },
    "seq2vec_encoder": {
       "type": "cls_pooler",
       "embedding_dim": transformer_dim,
       "cls_is_last_token": cls_is_last_token
    },
    "feedforward": {
      "input_dim": transformer_dim,
      "num_layers": 1,
      "hidden_dims": transformer_dim,
      "activations": "tanh"
    },
    "dropout": 0.3,
    "namespace": "tags"
  },
  "data_loader": {
        "batch_size": 4,

        "drop_last": true,
    },
  "trainer": {
    "num_epochs": 3,
    "cuda_device" : 0,
    "validation_metric": "+accuracy",
    "learning_rate_scheduler": {
      "type": "slanted_triangular",
      "cut_frac": 0.06
    },
    "optimizer": {
      "type": "huggingface_adamw",
      "lr": 2e-5,
      "weight_decay": 0.1,
    }
  }
}

The objective is to do textual entailment using RoBERTa, but after training the model for 2 epochs the accuracy (roughly 79%) hardly changed, and when I then ran the model for prediction it strangely predicted the same label for every instance in the test data.

I have a few doubts: is the model type basic_classifier mentioned in the default config correct here? Doesn't basic_classifier implement a normal text classifier?

The code I am using for prediction:

import numpy as np
import pandas as pd
from allennlp.predictors.predictor import Predictor
from allennlp_models import pair_classification  # registers the pair-classification models and predictors

# Test data with 'hypothesis' and 'sentence' (premise) columns
data = pd.read_csv(r'/home/episourcein.episource.com/espm1854/Downloads/context_3_classes.csv')

# Load the trained archive with the textual-entailment predictor
predictor = Predictor.from_path(
    "/home/episourcein.episource.com/espm1854/Documents/robertaTE/model.tar.gz",
    predictor_name="textual_entailment",
)

# Map label indices back to label strings
labels_dict = predictor._model.vocab.get_index_to_token_vocabulary('labels')

def get_labels(hypothesis, premise):
    pred = predictor.predict(hypothesis=hypothesis, premise=premise)
    return labels_dict[np.argmax(pred['probs'])]

data['predictions'] = data.apply(lambda x: get_labels(x['hypothesis'], x['sentence']), axis=1)
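
For reference, the distribution of predicted labels can be checked directly on the DataFrame above:

# How many of each label did the model actually produce?
print(data['predictions'].value_counts())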

If the model type here is wrong, what model type should I specify in place of basic_classifier to do textual entailment with RoBERTa? A sample config file for doing entailment with RoBERTa would really be helpful. Any help will be appreciated! @epwalsh @matt-gardner

dirkgr commented 4 years ago

I think you just need to try some more hyperparameters. In particular, try a lower learning rate (I'm a fan of 1e-5) and set "correct_bias": true in the parameters for the optimizer.
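
Concretely, that would mean changing the optimizer section of the config above to something like this (the learning rate is just a suggested starting point, not a tuned value):

"optimizer": {
  "type": "huggingface_adamw",
  "lr": 1e-5,
  "weight_decay": 0.1,
  "correct_bias": true
}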

Is it possible that your training set is really unbalanced, and that's why you get high accuracy with only one kind of output?
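
A quick way to check is to look at the label distribution of the training file directly; here is a sketch assuming it is standard SNLI-style jsonl with a gold_label field:

import pandas as pd

# Label distribution of the training data (field name "gold_label" assumed, as in standard SNLI jsonl)
train = pd.read_json("cnli_sampletrain_5L.jsonl", lines=True)
print(train["gold_label"].value_counts(normalize=True))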

github-actions[bot] commented 4 years ago

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇