Closed data-science-lover closed 2 years ago
For binary classification, the network output is a scalar (see networks.py). If you are getting 2 logits, you are doing something wrong.
Unless you are using AlexCifarNet, which is missing the conditional on num_classes but should be easily fixable. https://github.com/SsnL/dataset-distillation/blob/master/networks/networks.py#L53
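As a rough illustration of what such a conditional could look like, here is a minimal sketch. TinyNet and its constructor arguments are hypothetical names made up for this example, not code from networks.py; only the expression 1 if num_classes <= 2 else num_classes mirrors the condition discussed in this thread.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy network, used only to illustrate the conditional."""

    def __init__(self, in_features, num_classes):
        super().__init__()
        # One logit for binary classification, one logit per class otherwise.
        out_features = 1 if num_classes <= 2 else num_classes
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # Flatten everything except the batch dimension before the linear layer.
        return self.fc(x.flatten(1))


binary_out = TinyNet(8, 2)(torch.zeros(4, 8))      # shape (4, 1)
multi_out = TinyNet(8, 10)(torch.zeros(4, 8))      # shape (4, 10)
```

With this pattern a binary problem yields a single column of logits, which is what a thresholding-based prediction step expects.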
No, I use my own network, but indeed its output layer does not use the condition 1 if state.num_classes <= 2 else state.num_classes
This is deliberate: my model was trained with another algorithm (standard classification), so to use it directly for dataset distillation I have to keep two outputs rather than one.
Is it a problem?
In terms of model expressivity, (1) a single output with binary cross entropy and (2) two outputs with cross entropy are mathematically equivalent, assuming that the final layer is fully connected. In practice, I believe people usually use (1) for binary classification. I'm not sure which trains better.
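To make the equivalence concrete, here is a small sketch in plain Python (the function names are made up for this example). The softmax probability of class 1 over two logits z0 and z1 equals the sigmoid of their difference z1 - z0, so both losses give the same value. With a fully connected final layer, that difference is itself a linear function of the features, which is why a single output can represent it.

```python
import math


def two_logit_cross_entropy(z0, z1, target):
    """Softmax cross entropy over two logits; target is 0 or 1."""
    log_norm = math.log(math.exp(z0) + math.exp(z1))
    return log_norm - (z1 if target == 1 else z0)


def single_logit_bce(z, target):
    """Binary cross entropy on one logit z, with p(class 1) = sigmoid(z)."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -math.log(p) if target == 1 else -math.log(1.0 - p)


z0, z1 = 0.3, -1.2  # arbitrary example logits
for target in (0, 1):
    ce = two_logit_cross_entropy(z0, z1, target)
    bce = single_logit_bce(z1 - z0, target)
    print(target, ce, bce)  # the two loss values agree up to rounding
```

This only shows that the two parameterizations can express the same predictions and losses; it says nothing about which one optimizes better.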
Potentially, I imagine it could be more efficient in terms of computational cost and memory if I used a single output...
Thanks for the quick reply ;)
I noticed this problem when I wanted to test the distilled images (see basic.py):
The reason is the condition just above. For binary classification with a two-output network, the output has 32 rows (so 64 values), so (output > 0.5).to(target.dtype).view(-1) returns a 64-value tensor (with 32 values equal to 1), while the target contains only 32 values, which causes this problem.
So, to solve this problem, just apply output.argmax(-1) even for binary classification.
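The mismatch and the fix can be reproduced with a short sketch. It assumes a batch of 32 samples and a network with two output logits, and uses random tensors in place of real model outputs; the variable names are illustrative and not taken from basic.py.

```python
import torch

torch.manual_seed(0)
batch_size = 32
output = torch.randn(batch_size, 2)          # two logits per sample (assumed setup)
target = torch.randint(0, 2, (batch_size,))  # one label per sample

# Thresholding treats every logit as its own prediction: 64 values.
thresholded = (output > 0.5).to(target.dtype).view(-1)

# argmax picks one class index per row: 32 values, matching the target.
pred = output.argmax(-1)

print(thresholded.shape, pred.shape, target.shape)

# Comparing pred with target now works elementwise.
correct = (pred == target).sum().item()
```

Thresholding at 0.5 is only meaningful when the network emits a single value per sample, so with two logits argmax is the natural way to recover one predicted class per row.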