utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

Target Batch Size Doubling: ValueError: Target size (torch.Size([16, 5])) must be the same as input size (torch.Size([8, 5])) #291

Closed RaedShabbir closed 3 years ago

RaedShabbir commented 3 years ago

Strange error I keep getting when my dataset size becomes smaller; any help would be appreciated. For some reason the target batch size keeps doubling.

Traceback

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-17-fc2a900ad6bc> in <module>
----> 1 learner.lr_find(start_lr=1e-5,optimizer_type='lamb')

python3.8/site-packages/fast_bert/learner_cls.py in lr_find(self, start_lr, end_lr, use_val_loss, optimizer_type, num_iter, step_mode, smooth_f, diverge_th)
    672         for iteration in tqdm(range(num_iter)):
    673             # train on batch and retrieve loss
--> 674             loss = self._train_batch(train_iter)
    675             if use_val_loss:
    676                 loss = self.validate(quiet=True, loss_only=True)["loss"]

python3.8/site-packages/fast_bert/learner_cls.py in _train_batch(self, train_iter)
    719             if self.is_fp16:
    720                 with autocast():
--> 721                     outputs = self.model(**inputs)
    722             else:
    723                 outputs = self.model(**inputs)

python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

python3.8/site-packages/fast_bert/modeling.py in forward(self, input_ids, attention_mask, labels, head_mask)
     71             loss_fct = BCEWithLogitsLoss(weight=self.weight, pos_weight=self.pos_weight)
     72 
---> 73             loss = loss_fct(
     74                 logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
     75             )

python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    712         assert self.weight is None or isinstance(self.weight, Tensor)
    713         assert self.pos_weight is None or isinstance(self.pos_weight, Tensor)
--> 714         return F.binary_cross_entropy_with_logits(input, target,
    715                                                   self.weight,
    716                                                   pos_weight=self.pos_weight,

python3.8/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2825 
   2826     if not (target.size() == input.size()):
-> 2827         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2828 
   2829     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([16, 5])) must be the same as input size (torch.Size([8, 5]))

```
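
The shapes in the error message are enough to reproduce it on their own; here is a minimal sketch using only the sizes from the traceback (nothing fast-bert specific):

```python
import torch
from torch.nn import BCEWithLogitsLoss

loss_fct = BCEWithLogitsLoss()
logits = torch.randn(8, 5)    # "input size" from the error message
targets = torch.rand(16, 5)   # "target size" from the error message

# Raises: ValueError: Target size (torch.Size([16, 5])) must be the same as
# input size (torch.Size([8, 5]))
loss = loss_fct(logits, targets)
```
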
RaedShabbir commented 3 years ago

Possible Cause: I changed the number of labels in my data without changing the file paths or the model name, so when creating a DataBunch the model loads the cached features from previous runs. Those cached labels still carry the old label count, which seems to be what produces the size mismatch above. The databunch shows the output below:

Loading features from cached file <path to cache> 
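
If that diagnosis is right, the apparent batch-size doubling falls out of the `view(-1, self.num_labels)` reshapes in `fast_bert/modeling.py` shown in the traceback. A minimal sketch with hypothetical numbers (the old label count of 10 is an assumption; any stale count that differs from the current 5 changes the reshaped batch dimension):

```python
import torch

batch_size = 8        # batch from the traceback
new_num_labels = 5    # label count the model was built with
old_num_labels = 10   # hypothetical label count baked into the cached features

logits = torch.randn(batch_size, new_num_labels)         # model output: [8, 5]
cached_labels = torch.zeros(batch_size, old_num_labels)  # stale cache:  [8, 10]

# modeling.py reshapes both tensors with the *current* num_labels,
# so the stale label tensor gets folded into extra "rows":
target = cached_labels.view(-1, new_num_labels)          # [8, 10] -> [16, 5]

print(logits.shape, target.shape)
# torch.Size([8, 5]) torch.Size([16, 5])  -> exactly the mismatch in the error
```
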

Solution: The fastest way to fix this was to just delete the cache (see the sketch below), but a cache naming convention that reflects the label set would probably avoid this problem altogether.
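
For anyone hitting the same thing, the cleanup is roughly the following, assuming the cache lives in the directory that the "Loading features from cached file ..." line points at and that the files use a `cached_*` prefix (both are assumptions here; adjust to your own log output):

```python
from pathlib import Path

# Hypothetical location: point this at the directory shown in the
# "Loading features from cached file ..." line of your databunch output.
cache_dir = Path("data/cache")

# Delete the stale feature files so the databunch re-tokenizes
# with the current label columns on the next run.
for cached in cache_dir.glob("cached_*"):
    print(f"removing stale cache file: {cached}")
    cached.unlink()
```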