JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.
MIT License

Evaluation failed on MNLI and STSB Datasets for Last1.13release #29

Closed Labyrintbs closed 1 year ago

Labyrintbs commented 1 year ago

I followed the instructions to replicate the Last1.13release using the corresponding version's README.md, i.e.

```
python pretrain.py name=amp_b4096_c5_o3_final arch=bert-c5 train=bert-o3 train.batch_size=4096 data=bookcorpus-wikipedia

python eval.py eval=GLUE_sane name=amp_b4096_c5_o3_final eval.checkpoint=latest impl.microbatch_size=16 impl.shuffle_in_dataloader=True
```

The pretraining worked fine, except for a loss explosion with the default lr_scheduler budget-triangle2 in bert-o3.yaml, so I switched to budget-one-cycle based on the scheduler comparison in the paper, since the two have similar pretraining loss-decay behavior. In any case, pretraining finally reached a loss of 1.8282 on an RTX 2080 Ti in a single day, equivalent to the result reported in the paper.

The problems appeared during evaluation, on downstream tasks whose number of classes differs from 2: 3 classes for MNLI and a single regression target for STSB. For MNLI, I got RuntimeError: CUDA error: device-side assert triggered, or IndexError: Target 2 is out of bounds when moving the model to the CPU to dig for more information. For STSB, the loss computation fails with Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 2])).
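For context, here is a minimal sketch of the symptom (my own illustration, not code from this repo, and assuming the head follows the usual pattern of cross-entropy for integer labels and BCE-with-logits for float labels): a head built with num_labels=2 cannot digest MNLI or STSB targets.

```python
import torch
import torch.nn.functional as F

# Hypothetical illustration: logits from a head that was built with num_labels=2.
logits = torch.randn(16, 2)  # impl.microbatch_size=16, but only 2 outputs

try:
    F.cross_entropy(logits, torch.tensor([2] * 16))  # MNLI labels can be 0, 1 or 2
except IndexError as e:
    print(e)  # "Target 2 is out of bounds." (on CUDA this surfaces as a device-side assert)

try:
    F.binary_cross_entropy_with_logits(logits, torch.rand(16))  # STSB float targets
except ValueError as e:
    print(e)  # "Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 2]))"
```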

I checked the code carefully and found that the problem comes from one line in class ScriptableLMForSequenceClassification(PreTrainedModel):

```python
config.arch['num_labels'] = config.num_labels
```

(https://github.com/JonasGeiping/cramming/blob/4a5e3008a5ec05ed68f9d096e4875f8dddadcf81/cramming/architectures/scriptable_bert.py#L229)

which is initialized in downstream task function (https://github.com/JonasGeiping/cramming/blob/4a5e3008a5ec05ed68f9d096e4875f8dddadcf81/cramming/architectures/scriptable_bert.py#L24C1-L35C17)

```python
def construct_scriptable_bert(cfg_arch, vocab_size, downstream_classes=None):
    """See the config file for details on what is possible."""
    cfg_arch.embedding.vocab_size = vocab_size
    cfg_arch.num_labels = downstream_classes

    config = crammedBertConfig(OmegaConf.to_container(cfg_arch, resolve=True))
    if downstream_classes is None:
        model = ScriptableLMForPreTraining(config)
    else:
        model = ScriptableLMForSequenceClassification(config)

    return model


class crammedBertConfig(PretrainedConfig):
    model_type = "crammedBERT"

    def __init__(self, cfg_arch_container: dict = {}, **kwargs):
        self.arch = cfg_arch_container
        super().__init__(**kwargs)
```

All the modifications up to here work, and I realized that the args passed to ScriptableLMForSequenceClassification are carried as the arch attribute of the crammedBertConfig class, which inherits from the transformers library's base class PretrainedConfig.

```python
class ScriptableLMForSequenceClassification(PreTrainedModel):
    """Classification head and pooler."""

    config_class = crammedBertConfig

    def __init__(self, config):
        super().__init__(config)
        config.arch['num_labels'] = config.num_labels
        self.cfg = OmegaConf.create(config.arch)  # this could be nicer ...
        self.encoder = ScriptableLM(config)

        self.pooler = PoolingComponent(self.cfg.classification_head, self.cfg.hidden_size)
        self.head = torch.nn.Linear(self.cfg.classification_head.head_dim, self.cfg.num_labels)
```

However, this line of code, config.arch['num_labels'] = config.num_labels, just rewrites the final number of classes back to 2, because the default PretrainedConfig sets its num_labels attribute to 2: num_labels is never passed as a kwarg to PretrainedConfig.__init__, so the correct value stored in config.arch gets overwritten by the default.
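A quick way to see that default (my own sketch, assuming only that the transformers library is installed; the cfg stand-in below is hypothetical and not the repo's crammedBertConfig):

```python
from transformers import PretrainedConfig

# PretrainedConfig defaults num_labels to 2 unless it is passed explicitly.
print(PretrainedConfig().num_labels)  # -> 2

# Rough stand-in for what happens for MNLI (downstream_classes=3):
cfg = PretrainedConfig()
cfg.arch = {"num_labels": 3}             # set correctly via construct_scriptable_bert
cfg.arch["num_labels"] = cfg.num_labels  # ...then overwritten back to 2 in __init__
print(cfg.arch["num_labels"])            # -> 2, hence a 2-output classification head
```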

I commented out this line of code and it seems to work fine.
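If deleting the line feels too blunt, a guarded version would also work (just a sketch of my workaround, not the fix adopted upstream):

```python
# Inside ScriptableLMForSequenceClassification.__init__: only fall back to the
# PretrainedConfig default when construct_scriptable_bert did not set a label count.
if config.arch.get("num_labels") is None:
    config.arch["num_labels"] = config.num_labels
```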

As this released version is fairly old compared to the newest Torch 2.1, I don't think it's worth opening a PR, so I'm leaving an issue here in case someone encounters the same problem as me :)

JonasGeiping commented 1 year ago

Hi, congrats on training your model! I think this is the same problem as https://github.com/JonasGeiping/cramming/issues/24? Yeah this should be fixed in the new version.

Labyrintbs commented 1 year ago

> Hi, congrats on training your model! I think this is the same problem as #24? Yeah this should be fixed in the new version.

Oh thanks! Exactly the same problem, I didn't notice this issue before ;)

JonasGeiping commented 1 year ago

Great, closing this for now.