I checked your code. Could label2one_hot_torch have a bug? I don't see where you convert the labels from N x H x W to N x 1 x H x W.
for i, (input, label) in enumerate(train_loader):
    iteration = i + epoch * epoch_size
    # ========== adjust lr if necessary ===============
    if (iteration + 1) in args.milestones:
        state_dict = optimizer.state_dict()
        for param_group in state_dict['param_groups']:
            param_group['lr'] = args.lr * ((0.5) ** (args.milestones.index(iteration + 1) + 1))
        optimizer.load_state_dict(state_dict)
    # ========== complete data loading ================
    label_1hot = label2one_hot_torch(label.to(device), C=50)
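
To make the shape concern concrete, here is a minimal sketch of the conversion I would expect, assuming the labels are an integer tensor of shape N x H x W. The helper name one_hot_from_labels is hypothetical and is not the repository's label2one_hot_torch; it only illustrates why the extra channel dimension matters when the one-hot tensor is built with scatter_.

```python
import torch


def one_hot_from_labels(label: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper (not the repo's label2one_hot_torch).

    Assumes `label` holds integer class ids in [0, num_classes).
    """
    if label.dim() == 3:
        # N x H x W -> N x 1 x H x W, so the index has as many dims as the target
        label = label.unsqueeze(1)
    n, _, h, w = label.shape
    one_hot = torch.zeros(n, num_classes, h, w, dtype=torch.float32, device=label.device)
    # scatter_ requires an int64 index with the same number of dimensions as `one_hot`
    return one_hot.scatter_(1, label.long(), 1.0)


label = torch.randint(0, 50, (2, 4, 4))           # N x H x W
label_1hot = one_hot_from_labels(label, num_classes=50)
print(label_1hot.shape)                            # N x C x H x W
```

If the function in the repo calls scatter_ directly on an N x H x W tensor, I would expect a dimension mismatch error, so it may already handle this internally.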