utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0
1.85k stars 342 forks

batch_size_per_gpu and max_seq_length are unexpected args? #219

Closed: ktfhale closed this issue 4 years ago

ktfhale commented 4 years ago

I'm trying to construct an instance of a BertDataBunch, but the constructor doesn't seem to recognize two of the standard arguments.

from fast_bert.data import BertDataBunch
DATA_PATH = '/home/khale/FastBert/'
LABEL_PATH = '/home/khale/FastBert/'

databunch = BertDataBunch(DATA_PATH, LABEL_PATH,
                          tokenizer='bert-base-cased',
                          train_file='fastbert_train1.csv',
                          val_file='fastbert_val1.csv',
                          label_file='labels.csv',
                          text_col='text',
                          label_col='label',
                          batch_size_per_gpu=16,
                          max_seq_length=512,
                          multi_gpu=False,
                          multi_label=False,
                          model_type='bert')

Which yields:

TypeError: __init__() got an unexpected keyword argument 'batch_size_per_gpu'

If I comment out batch_size_per_gpu, I get the same TypeError on max_seq_length instead. If I comment out both, the BertDataBunch constructs fine, presumably because __init__ falls back to its default values. However, building the BertLearner then fails, because BertDataBunch has no model_type attribute:

~/anaconda3/envs/riskenv/lib/python3.7/site-packages/fast_bert/learner_cls.py in from_pretrained_model(dataBunch, pretrained_path, output_dir, metrics, device, logger, finetuned_wgts_path, multi_gpu, is_fp16, loss_scale, warmup_steps, fp16_opt_level, grad_accumulation_steps, multi_label, max_grad_norm, adam_epsilon, logging_steps, freeze_transformer_layers)
    131         model_state_dict = None
    132 
--> 133         model_type = dataBunch.model_type
    134 
    135         if torch.cuda.is_available():

AttributeError: 'BertDataBunch' object has no attribute 'model_type'

I'm at a loss as to what could cause this. I installed fast-bert by cloning the repo and installing it with pip, though into a conda environment. That's never ideal, but fast-bert is the last package installed when building the environment, and the only one I've installed with pip, so that should limit potential problems. Any help would be much appreciated.
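For anyone hitting the same pair of errors, one way to narrow them down before suspecting the installation is to ask Python directly which keyword arguments a constructor accepts and which attributes the resulting object has. The sketch below uses the standard `inspect` module. `unexpected_kwargs` and `LegacyDataBunch` are hypothetical names made up for illustration; `LegacyDataBunch` is a stand-in with an invented signature, not fast-bert's actual class.

```python
import inspect


def unexpected_kwargs(func, **kwargs):
    """Return the names in kwargs that func's signature would reject.

    Works on classes too: inspect.signature(cls) reports the
    parameters of __init__ without `self`.
    """
    params = inspect.signature(func).parameters.values()
    # A **kwargs parameter swallows any keyword, so nothing is rejected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []
    # Only these two kinds can be passed by keyword.
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return [name for name in kwargs if name not in accepted]


class LegacyDataBunch:
    """Hypothetical stand-in for a constructor with an older signature."""

    def __init__(self, data_dir, label_dir, tokenizer, bs=32, maxlen=512):
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.tokenizer = tokenizer
        self.bs = bs
        self.maxlen = maxlen
        # Note: no self.model_type, mirroring the AttributeError above.


print(unexpected_kwargs(LegacyDataBunch,
                        batch_size_per_gpu=16, max_seq_length=512))
# ['batch_size_per_gpu', 'max_seq_length']

bunch = LegacyDataBunch("/data", "/labels", tokenizer="bert-base-cased")
print(hasattr(bunch, "model_type"))  # False
```

Listing every rejected name at once avoids the comment-out-and-retry loop, and a surprising list is a strong hint that the class in hand is not the one the documentation describes.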

ktfhale commented 4 years ago

This can be closed or deleted: the problem was that I imported BertDataBunch from fast_bert.data rather than fast_bert.data_cls, as the docs actually specify. Don't I feel silly.
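A cheap guard against this kind of mix-up is to check where an imported class was actually defined, via its `__module__` attribute. This is a minimal sketch; `check_origin` is a hypothetical helper made up for illustration. For fast-bert one would presumably pass `"fast_bert.data_cls"`, assuming BertDataBunch is defined in that module as the documented import path suggests. The demo uses the standard library's `json.JSONDecoder`, which is importable from `json` but defined in `json.decoder`, so it runs without fast-bert installed.

```python
from json import JSONDecoder


def check_origin(cls, expected_module):
    """Raise ImportError unless cls was defined in expected_module.

    __module__ names the module where the class was *defined*, which
    can differ from the package it was imported through.
    """
    actual = cls.__module__
    if actual != expected_module:
        raise ImportError(
            f"{cls.__name__} comes from {actual!r}, expected {expected_module!r}"
        )
    return cls


check_origin(JSONDecoder, "json.decoder")  # passes: defined in json/decoder.py

try:
    check_origin(JSONDecoder, "json")
except ImportError as err:
    print(err)  # JSONDecoder comes from 'json.decoder', expected 'json'
```

In the script from the original post, the analogous (assumed) call would be `check_origin(BertDataBunch, "fast_bert.data_cls")` right after the import, which should fail fast if the class came from a different module.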