clovaai / CutMix-PyTorch

Official PyTorch implementation of CutMix regularizer
MIT License
1.22k stars 159 forks

Consistency between code and paper #37

Closed shim94kr closed 3 years ago

shim94kr commented 3 years ago

Hello, I'm a student from Postech.

It was interesting to follow your work. But, I think the following code is inconsistent with the paper. https://github.com/clovaai/CutMix-PyTorch/blob/2d8eb68faff7fe4962776ad51d175c3b01a25734/train.py#L240

In the code, the loss is interpolated between the two targets, but in the paper the labels are mixed into a single target, as described in Appendix A1.

Is there a reason for this change, or have I misunderstood the paper?

I look forward to your response. Thank you.

SanghyukChun commented 3 years ago

Thanks for the question. In the paper we propose mixing the two targets, not using a single target, as you can see in Equation (1) and Appendix A1.

Equation (1)

Appendix A1

Here, each "y" is a one-hot label (so an example of \hat y is [0.6, 0.4, 0, 0, 0]). Note that computing the cross-entropy between the mixed target and the prediction is equivalent to mixing two cross-entropies, each computed against a single target, i.e.,

CE(mixed_label_between_ya_yb, prediction) = lambda * CE(ya, prediction) + (1 - lambda) * CE(yb, prediction)

The above equation follows directly from the definition of cross-entropy, since CE is linear in its target argument. Nothing has changed since the first version of the CutMix algorithm, and our code is consistent with our paper. Thanks
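A quick numerical check of this identity (a pure-Python sketch with made-up labels and predictions, not code from the repo):

```python
import math

def cross_entropy(target_probs, pred_probs):
    """CE(target, pred) = -sum_c target_c * log(pred_c)."""
    return -sum(t * math.log(p) for t, p in zip(target_probs, pred_probs) if t > 0)

lam = 0.6
ya = [1.0, 0.0, 0.0]  # one-hot label of image A
yb = [0.0, 1.0, 0.0]  # one-hot label of image B
mixed = [lam * a + (1 - lam) * b for a, b in zip(ya, yb)]  # [0.6, 0.4, 0.0]

pred = [0.7, 0.2, 0.1]  # hypothetical softmax output

lhs = cross_entropy(mixed, pred)
rhs = lam * cross_entropy(ya, pred) + (1 - lam) * cross_entropy(yb, pred)
assert abs(lhs - rhs) < 1e-12  # the two forms agree
```

Because CE is a weighted sum of log-probabilities, mixing the targets and mixing the losses give the same value.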

shim94kr commented 3 years ago

I greatly appreciate your prompt response!

I hadn't noticed that the equivalence holds!

But I'm still wondering: is there a particular issue with implementing CE(mixed_label_between_ya_yb, prediction) directly? Why did you implement the equivalent form rather than the original one? At first glance, the original formula seems equally easy to implement.

Thank you!

SanghyukChun commented 3 years ago

@shim94kr Technically speaking, the CE in my comment and the CE in PyTorch are not equivalent. The cross-entropy implementation in PyTorch, and in most ML frameworks, takes a class "index" rather than a probability distribution. I.e., for PyTorch's CE the "target" is not a one-hot vector (e.g., [1, 0, 0, 0, ...]) but an index value (e.g., 34). Since the built-in cross-entropy implementation is highly optimized ("This criterion combines LogSoftmax and NLLLoss in one single class", from the documentation), we do not write our own CE implementation. See the official documentation for details.
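The mixed-loss form in the repo's train.py can be sketched as follows (dummy logits and labels for illustration; the variable names are mine, not necessarily those in the repo):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # expects integer class indices, not one-hot vectors

logits = torch.randn(4, 10)              # batch of 4 predictions over 10 classes (dummy)
target_a = torch.tensor([3, 1, 0, 7])    # labels of the original images
target_b = torch.tensor([5, 5, 2, 9])    # labels of the pasted patches
lam = 0.6                                # patch area ratio

# Mix the two losses instead of the two one-hot targets:
# mathematically identical, but reuses the optimized built-in CE.
loss = lam * criterion(logits, target_a) + (1 - lam) * criterion(logits, target_b)
```

Mixing the losses this way avoids building soft (one-hot-mixed) target tensors just to feed a criterion that only accepts indices.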

shim94kr commented 3 years ago

I'm now clear on this issue. Thank you!