utterworks / fast-bert


MultiLabel Classification ERROR: Target size (torch.Size([176, 1])) must be the same as input size (torch.Size([16, 1])) #226

Open pvester opened 4 years ago

pvester commented 4 years ago

Hey,

The error is that my 11 labels somehow get flattened into 16 × 11 = 176 target values for one batch of 16, which does not match the model output size of 16.

I have set up everything according to the guide, with train.csv, val.csv and labels.csv, and I have set multi_label = True. What could be wrong?
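
Roughly, my setup looks like this (just a sketch; the paths and label column names here are placeholders, my real data has 11 label columns):

```python
from pathlib import Path
from fast_bert.data_cls import BertDataBunch

DATA_PATH = Path('./data')     # placeholder: folder with train.csv and val.csv
LABEL_PATH = Path('./labels')  # placeholder: folder with labels.csv

databunch = BertDataBunch(
    DATA_PATH, LABEL_PATH,
    tokenizer='bert-base-uncased',
    train_file='train.csv',
    val_file='val.csv',
    label_file='labels.csv',
    text_col='text',
    label_col=['label_1', 'label_2'],  # placeholders; one column per label
    batch_size_per_gpu=16,
    max_seq_length=512,
    multi_label=True,
    model_type='bert')
```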

ssalawu commented 4 years ago

Did you make any progress on this?

sevinjyolchuyeva commented 4 years ago

Did you make any progress on this?

I am hitting the same issue.

ssalawu commented 4 years ago

I narrowed the issue down to the size of my dataset and experimented with different dataset sizes until the error disappeared.

sevinjyolchuyeva commented 4 years ago

Thanks for the response. The problem is the number of labels: when I print the number of labels, it equals 1.

num_labels = len(databunch.labels)

print(num_labels)

It should equal three for my task.
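
A quick way to check whether the file itself is the problem (a sketch; it assumes labels.csv has one label per line with no header, as the guide expects, and LABEL_PATH is wherever the file lives):

```python
from pathlib import Path

LABEL_PATH = Path('./labels')  # placeholder

# labels.csv should contain exactly one label per line, no header row
with open(LABEL_PATH / 'labels.csv') as f:
    labels = [line.strip() for line in f if line.strip()]

print(len(labels), labels)  # should print 3 and the three label names
```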

aubluce commented 3 years ago

I'm having the same issue. @sevinjyolchuyeva did you ever get yours to work? I can't quite tell from the comments whether this is resolved or not.

ENGSamShamsan commented 3 years ago

I wish you could elaborate on the issue once it's fixed. I am having the same issue here: Using a target size (torch.Size([50, 128])) that is different to the input size (torch.Size([50, 21])) is deprecated. Please ensure they have the same size.
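
Both errors in this thread come from the same underlying problem: the binary cross-entropy loss used for multi-label training requires the target tensor to have the same shape as the model output. A minimal plain-PyTorch sketch (not fast-bert code) reproducing the mismatch with the shapes above:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(50, 21)   # model output: batch of 50 examples, 21 labels
targets = torch.rand(50, 128)  # targets built from 128 label columns

# Raises "Target size (torch.Size([50, 128])) must be the same as
# input size (torch.Size([50, 21]))" because the shapes differ
loss = F.binary_cross_entropy_with_logits(logits, targets)

# Works once the target width matches the model's label head
targets_ok = torch.rand(50, 21)
loss = F.binary_cross_entropy_with_logits(logits, targets_ok)
```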

adamlamine commented 2 years ago

> Thanks for the response. The problem is the number of labels: when I print the number of labels, it equals 1.
>
> num_labels = len(databunch.labels)
>
> print(num_labels)
>
> It should equal three for my task.

Same here. I have 3 labels.

databunch.labels gives me: ['nan', '0.0', '1.0', '2.0']

I tried del databunch.labels[0]

after that databunch.labels gives me: ['0.0', '1.0', '2.0']

and learner.fit now works.

Edit: In the end, this was my own error. My labels.csv file had one row too many in it.
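
For anyone else who lands here: the stray 'nan' entry suggests a blank or extra line in labels.csv, so cleaning the file before building the databunch avoids deleting entries from databunch.labels by hand. A sketch (the path is a placeholder):

```python
import pandas as pd

# labels.csv should list one label per line with no header;
# drop any blank/NaN rows left by an accidental extra line
labels = pd.read_csv('labels/labels.csv', header=None).dropna()
labels.to_csv('labels/labels.csv', header=False, index=False)
```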