apple / ml-cvnets

CVNets: A library for training computer vision networks
https://apple.github.io/ml-cvnets

A code defect that needs to be fixed #63

Open Light-Reflection opened 1 year ago

Light-Reflection commented 1 year ago

In RangeAugment mode, if I set learn_augmentation.mode to 'none' in ml-cvnets\examples\range_augment\classificationresnet_50.yaml (because I want to run some ablation experiments for RangeAugment), line 33 in cross_entropy_with_neural_aug.py raises an error. The cause is that line 240 in base_cls.py returns the data as 'x', and 'x' is a tensor, but the loss expects a dict with a 'logits' entry. You need to put x into an 'out_dict', like this:

out_dict = {"augmented_tensor": None, "logits": x}
return out_dict

With this change, it works.
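A minimal sketch of the pattern described above, written in plain Python so it runs without the library. The helper names `wrap_model_output` and `get_logits` are hypothetical and are not part of ml-cvnets; only the dict keys `augmented_tensor` and `logits` come from the snippet in this issue.

```python
from typing import Any, Dict, Optional, Union


def wrap_model_output(x: Any) -> Dict[str, Optional[Any]]:
    """Hypothetical helper: package a bare prediction in the same dict
    layout used when augmentation is enabled (keys taken from this issue)."""
    return {"augmented_tensor": None, "logits": x}


def get_logits(prediction: Union[Dict[str, Any], Any]) -> Any:
    """Hypothetical helper: a loss-side guard that accepts either the dict
    output or a bare tensor, so both model paths keep working."""
    if isinstance(prediction, dict):
        return prediction["logits"]
    return prediction


# Stand-in for a tensor; a plain list keeps the sketch dependency-free.
x = [0.1, 0.7, 0.2]
out_dict = wrap_model_output(x)
logits = get_logits(out_dict)
```

Either side of the fix would be enough on its own: the model can always return the dict, or the loss can tolerate a bare tensor.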

farzadab commented 1 year ago

Yes, if you disable learn_augmentation, you should also change both the loss to cross_entropy and stats.checkpoint_metric to top1 (instead of top1.logits).
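As a rough illustration of the overrides mentioned above, the sketch below applies them to a nested dict that stands in for the YAML config. The key `loss.name`, the starting loss value, and the helper `set_by_dotted_key` are assumptions made for the example; the real config layout in ml-cvnets may differ, so check the YAML file before copying any key.

```python
from typing import Any, Dict


def set_by_dotted_key(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Hypothetical helper: set a nested value using a dotted key path."""
    keys = dotted_key.split(".")
    node = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value


# Stand-in config; the structure is assumed, not copied from the repository.
config: Dict[str, Any] = {
    "learn_augmentation": {"mode": "none"},             # augmentation disabled
    "loss": {"name": "cross_entropy_with_neural_aug"},  # assumed key and value
    "stats": {"checkpoint_metric": "top1.logits"},
}

# The two changes suggested in the reply above.
set_by_dotted_key(config, "loss.name", "cross_entropy")
set_by_dotted_key(config, "stats.checkpoint_metric", "top1")
```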