OK, I've analyzed this error. It happens because MixupCutmix occasionally transforms the targets tensor into a soft one-hot encoding, whereas the current FocalLoss and SmoothCE require argmax (index) labels.
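For context, here is a minimal sketch of the shape/dtype mismatch. This assumes the collator behaves like `timm.data.Mixup`; the project's MixupCutmix collator may wrap it differently:

```python
import torch
from timm.data import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=10)

images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 10, (8,))            # index labels, shape [8], dtype int64

mixed_images, soft_targets = mixup_fn(images, labels)
print(labels.shape, labels.dtype)              # torch.Size([8]) torch.int64
print(soft_targets.shape, soft_targets.dtype)  # torch.Size([8, 10]) torch.float32
```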
For FocalLoss, it can easily be fixed like this:
```python
from typing import Any, Dict

import torch
from torch import nn
from torchvision.ops import sigmoid_focal_loss  # assuming torchvision's focal loss is used here


class FocalLoss(nn.Module):
    def forward(self, outputs: Dict[str, Any], batch: Dict[str, Any], device: torch.device):
        outputs = outputs['outputs']
        targets = move_to(batch['targets'], device)
        num_classes = outputs.shape[-1]
        # Targets need to be one-hot encoded; mixup/cutmix already delivers them
        # in that shape, so only convert when we still have index labels
        if outputs.shape != targets.shape:
            targets = nn.functional.one_hot(targets, num_classes=num_classes)
            targets = targets.float().squeeze()
        loss = sigmoid_focal_loss(outputs, targets, self.alpha, self.gamma, self.reduction)
        loss_dict = {"L": loss.item()}
        return loss, loss_dict
```
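As a quick standalone sanity check of the shape handling (assuming the loss wraps `torchvision.ops.sigmoid_focal_loss`, which accepts float one-hot or soft targets with the same shape as the logits):

```python
import torch
from torch.nn import functional as F
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(4, 10)

# hard index labels -> convert to one-hot before calling the loss
hard = torch.randint(0, 10, (4,))
hard_onehot = F.one_hot(hard, num_classes=10).float()

# soft targets (e.g. from mixup) already match the logits' shape
soft = torch.rand(4, 10)
soft = soft / soft.sum(dim=1, keepdim=True)

print(sigmoid_focal_loss(logits, hard_onehot, alpha=0.25, gamma=2.0, reduction='mean'))
print(sigmoid_focal_loss(logits, soft, alpha=0.25, gamma=2.0, reduction='mean'))
```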
For SmoothCE, timm has a SoftTargetCrossEntropy class that handles soft one-hot targets:
```python
from typing import Any, Dict

import torch
from torch import nn
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy


class SmoothCELoss(nn.Module):
    def __init__(self, smoothing: float = 0.1, **kwargs):
        super(SmoothCELoss, self).__init__(**kwargs)
        self.smooth_criterion = LabelSmoothingCrossEntropy(smoothing=smoothing)
        self.soft_criterion = SoftTargetCrossEntropy()

    def forward(self, outputs: Dict[str, Any], batch: Dict[str, Any], device: torch.device):
        pred = outputs['outputs']
        target = move_to(batch["targets"], device)
        # Soft (mixup/cutmix) targets share the logits' shape; index labels do not
        if pred.shape == target.shape:
            loss = self.soft_criterion(pred, target)
        else:
            loss = self.smooth_criterion(pred, target.view(-1).contiguous())
        loss_dict = {"CE": loss.item()}
        return loss, loss_dict
```
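A quick standalone check of how the two timm criteria behave (dummy tensors instead of the pipeline's batch dict, so `move_to` is not needed here):

```python
import torch
from torch.nn import functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(4, 10)
hard = torch.randint(0, 10, (4,))               # shape [4]     -> label-smoothing branch
soft = F.one_hot(hard, num_classes=10).float()  # shape [4, 10] -> soft-target branch

print(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, hard))
print(SoftTargetCrossEntropy()(logits, soft))
```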
I will make a PR and you can test it out
I encountered these errors when testing with `SmoothCELoss` and `FocalLoss`; the logs are below:

**Focal loss**
```
2022-03-27 11:07:46 | INFO | stdout_logger.py:log_text:28 - Number of training samples: 88814
2022-03-27 11:07:46 | INFO | stdout_logger.py:log_text:28 - Number of validation samples: 21775
2022-03-27 11:07:46 | INFO | stdout_logger.py:log_text:28 - Number of training iterations each epoch: 2775
2022-03-27 11:07:46 | INFO | stdout_logger.py:log_text:28 - Number of validation iterations each epoch: 681
2022-03-27 11:07:46 | INFO | stdout_logger.py:log_text:28 - Everything will be saved to /content/main/runs/2022-03-27_11-07-21
2022-03-27 11:07:46 | DEBUG | stdout_logger.py:log_text:34 - Saving config to /content/main/runs/2022-03-27_11-07-21/pipeline.yaml...
2022-03-27 11:07:46 | DEBUG | stdout_logger.py:log_text:34 - Saving config to /content/main/runs/2022-03-27_11-07-21/transform.yaml...
2022-03-27 11:07:46 | DEBUG | stdout_logger.py:log_text:34 - Start sanity checks
2022-03-27 11:07:47 | DEBUG | stdout_logger.py:log_text:34 - Visualizing architecture...
2022-03-27 11:07:50 | INFO | stdout_logger.py:log_text:28 - =============================EVALUATION===================================
100% 681/681 [04:04<00:00, 2.78it/s]
2022-03-27 11:11:56 | INFO | stdout_logger.py:log_text:28 - [0|3000] || L: 0.13242 || Time: 2.7617 (it/s)
2022-03-27 11:11:56 | INFO | stdout_logger.py:log_text:28 - acc: 0.00455 | bl_acc: 0.00411 | weighted-f1: 0.00332 |
2022-03-27 11:11:56 | INFO | stdout_logger.py:log_text:28 - ==========================================================================
2022-03-27 11:11:57 | DEBUG | stdout_logger.py:log_text:34 - Visualizing model predictions...
2022-03-27 11:11:59 | DEBUG | stdout_logger.py:log_text:34 - Visualizing dataset...
2022-03-27 11:12:01 | DEBUG | stdout_logger.py:log_text:34 - Analyzing datasets...
100% 88814/88814 [12:01<00:00, 123.05it/s]
100% 21775/21775 [02:12<00:00, 163.82it/s]
2022-03-27 11:26:17 | INFO | stdout_logger.py:log_text:28 - ===========================START TRAINING=================================
Traceback (most recent call last):
  File "/content/main/configs/classification/train.py", line 10, in <module>
    train_pipeline.fit()
  File "/content/main/theseus/classification/pipeline.py", line 171, in fit
    self.trainer.fit()
  File "/content/main/theseus/base/trainer/base_trainer.py", line 65, in fit
    self.training_epoch()
  File "/content/main/theseus/base/trainer/supervised_trainer.py", line 68, in training_epoch
    outputs = self.model.training_step(batch)
  File "/content/main/theseus/classification/models/wrapper.py", line 34, in training_step
    return self.forward(batch)
  File "/content/main/theseus/classification/models/wrapper.py", line 22, in forward
    loss, loss_dict = self.criterion(outputs, batch, self.device)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/main/theseus/classification/losses/focal_loss.py", line 21, in forward
    targets = nn.functional.one_hot(targets, num_classes=num_classes)
RuntimeError: one_hot is only applicable to index tensor.
```
I guess it's an error from the Mixup/Cutmix collator, something with `torch.int64`. Here's the link to the notebook I've used for testing, if you want to have a look: notebook
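For what it's worth, the RuntimeError above matches that dtype guess: `F.one_hot` only accepts integer index tensors, while the Mixup/Cutmix collator returns float soft targets. A minimal illustration:

```python
import torch
from torch.nn import functional as F

hard = torch.randint(0, 10, (4,))              # int64 index tensor -> works
print(F.one_hot(hard, num_classes=10).shape)   # torch.Size([4, 10])

soft = torch.rand(4, 10)                       # float soft targets, as produced by mixup/cutmix
F.one_hot(soft.argmax(dim=1), num_classes=10)  # works after argmax
F.one_hot(soft, num_classes=10)                # RuntimeError: one_hot is only applicable to index tensor.
```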