ssnl / dataset-distillation

Open-source code for paper "Dataset Distillation"
https://ssnl.github.io/dataset_distillation
MIT License

The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 0 #54

Closed data-science-lover closed 2 years ago

data-science-lover commented 2 years ago

I noticed this problem when I wanted to test the distilled images (see basic.py):

[screenshot: traceback showing "The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 0"]

The cause is the condition just above it. For binary classification with a batch of 32, my network returns two logits per sample, so output has shape (32, 2), i.e. 64 values. Doing (output > 0.5).to(target.dtype).view(-1) therefore produces a 64-value tensor, while target contains only 32 values, which triggers this size mismatch.

So, to solve this problem, just apply output.argmax(-1) even for binary classification.
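For concreteness, here is a minimal sketch of the mismatch and of the fix (not the repo's exact code; the batch size of 32 and the 2-logit head are assumptions matching my setup):

```python
import torch

batch = 32
target = torch.randint(0, 2, (batch,))  # ground-truth labels, shape (32,)
output = torch.randn(batch, 2)          # two logits per sample, shape (32, 2)

# Original binary branch: thresholding and flattening a (32, 2) tensor
# yields 64 predictions, which cannot be compared against 32 targets.
# pred = (output > 0.5).to(target.dtype).view(-1)  # shape (64,) -> size mismatch

# Proposed fix: argmax over the class dimension works for any number of logits.
pred = output.argmax(-1)                # shape (32,)
accuracy = (pred == target).float().mean()
```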

ssnl commented 2 years ago

If it is binary classification, the network output is a scalar (see networks.py). If you are getting 2 logits, you are doing something wrong.

ssnl commented 2 years ago

Unless you are using AlexCifarNet, which is missing the conditional on num_classes but should be easily fixable. https://github.com/SsnL/dataset-distillation/blob/master/networks/networks.py#L53
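A sketch of the fix (assumed, mirroring the conditional the other networks in networks.py use; the feature width 192 below is a placeholder, not AlexCifarNet's actual value):

```python
import torch.nn as nn

num_classes = 2  # stands in for state.num_classes

# Size the final fully connected layer with the same conditional as the
# other networks: a single logit for binary classification.
fc = nn.Linear(192, 1 if num_classes <= 2 else num_classes)
```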

data-science-lover commented 2 years ago

> Unless you are using AlexCifarNet, which is missing the conditional on num_classes but should be easily fixable. https://github.com/SsnL/dataset-distillation/blob/master/networks/networks.py#L53

No, I use my own network, but indeed the output does not use the condition 1 if state.num_classes <= 2 else state.num_classes.

It's a deliberate choice: my model was trained with another algorithm (classical classification), so to be able to use it directly for dataset distillation I had to keep a two-value output rather than a single one.

Is it a problem?

ssnl commented 2 years ago

In terms of model expressivity, (1) a single output with binary cross entropy and (2) two outputs with cross entropy are mathematically equivalent, assuming the final layer is fully connected. In practice I believe people usually use (1) for binary classification. I'm not sure which trains better.
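A quick numerical check of the equivalence (a sketch, not code from this repo): softmax over two logits satisfies softmax([z0, z1])[1] = sigmoid(z1 - z0), so two-logit cross entropy matches single-logit BCE applied to the logit difference.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(32, 2)                 # two logits per sample
target = torch.randint(0, 2, (32,))

# (2) two outputs + cross entropy
ce = F.cross_entropy(z, target)
# (1) single output + binary cross entropy, on the logit difference
bce = F.binary_cross_entropy_with_logits(z[:, 1] - z[:, 0], target.float())

print(torch.allclose(ce, bce, atol=1e-6))  # True
```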

data-science-lover commented 2 years ago

Potentially; I imagine a single output could be slightly more efficient in terms of compute cost and memory...

Thanks for your quick replies ;)