timeseriesAI / tsai

Topics: time series, timeseries, deep learning, machine learning, Python, PyTorch, fastai
State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.16k stars, 644 forks

Issues with TSMultiLabelClassification #533

Closed: gitusrnm closed this issue 1 year ago

gitusrnm commented 2 years ago

Hi and thanks for the great library!

I'm having some issues with multi-label classification. Training on my dataset succeeded, but I got an error at inference.

I tried it with a single sample:

pred = learn.get_X_preds(X=np.array([X[splits[1]][0]]), bs=1)

and got:

      python3.8/site-packages/torch/_tensor.py:1051, in Tensor.__torch_function__(cls, func, types, args, kwargs)
         1048     return NotImplemented
         1050 with _C.DisableTorchFunction():
      -> 1051     ret = func(*args, **kwargs)
         1052     if func in get_default_nowrap_functions():
         1053         return ret
      RuntimeError: Boolean value of Tensor with more than one value is ambiguous
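For reference, tsai expects X as a 3D array shaped (samples, variables, time steps), so even a single prediction needs a leading batch axis of size 1. The NumPy sketch below uses a toy array and a hypothetical `valid_idx` in place of the real X and splits[1]. It only checks that the `np.array([...])` wrapper above yields that shape and shows two equivalent one-liners. It does not explain the RuntimeError, which PyTorch raises in general whenever a tensor holding more than one value is used where a single True/False is expected.

```python
import numpy as np

# Hypothetical stand-ins (not the issue's data): 10 univariate samples of
# length 140, laid out as (n_samples, n_vars, seq_len).
X = np.arange(10 * 140, dtype=np.float32).reshape(10, 1, 140)
valid_idx = np.arange(5, 10)  # plays the role of splits[1]

# Form used in the issue: pick one sample, then re-wrap it in a list.
one_a = np.array([X[valid_idx][0]])

# Equivalent forms that keep the batch axis directly.
one_b = X[valid_idx[:1]]       # slice the index array instead of the data
one_c = X[valid_idx][0][None]  # add a leading axis of size 1

print(X[valid_idx][0].shape)                  # (1, 140): batch axis dropped
print(one_a.shape, one_b.shape, one_c.shape)  # (1, 1, 140) each
```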

While looking into the problem, I ended up rerunning the code from 01a_MultiClass_MultiLabel_TSClassification.ipynb:

from tsai.all import * 

dsid = 'ECG5000' 
X, y, splits = get_UCR_data(dsid, split_data=False)

class_map = {
    '1':['Nor'],          # N:1  - Normal
    '2':['RoT', 'Pre'],   # r:2  - R-on-T premature ventricular contraction
    '3':['PVC', 'Pre'],   # V:3  - Premature ventricular contraction
    '4':['SPC', 'Pre'],   # S:4  - Supraventricular premature or ectopic beat (atrial or nodal)
    '5':['Unk'],          # Q:5  - Unclassifiable beat
}
labeler = ReLabeler(class_map)
y_multi = labeler(y)

tfms = [None, TSMultiLabelClassification()]  # TSMultiLabelClassification() == [MultiCategorize(), OneHotEncode()]
batch_tfms = [TSStandardize()]
dls = get_ts_dls(X, y_multi, splits=splits, tfms=tfms, batch_tfms=batch_tfms, bs=[64, 128])

learn = ts_learner(dls, InceptionTimePlus, loss_func=BCEWithLogitsLossFlat(), cbs=[ShowGraph()])
learn.fit_one_cycle(1, lr_max=1e-3)

and got an error:

    python3.8/site-packages/torch/nn/functional.py:2980, in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
       2977     reduction_enum = _Reduction.get_enum(reduction)
       2979 if not (target.size() == input.size()):
    -> 2980     raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
       2982 return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
    ValueError: Target size (torch.Size([384])) must be the same as input size (torch.Size([2304]))
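Both sizes in the ValueError are 1D, i.e. flattened element counts. The plain-Python sketch below is a stand-in, not tsai's ReLabeler or TSMultiLabelClassification (the sorted label order is an assumption): it multi-hot encodes the class_map from the reproduction and redoes the arithmetic. Assuming the failing batch was a training batch of 64 (the first value in bs=[64, 128]), the 384 target values are 64 samples × 6 labels, while the 2304 model outputs come to 36 per sample. This only restates the mismatch; it does not identify its cause.

```python
# Hypothetical re-implementation for illustration only; this is NOT tsai's
# ReLabeler / TSMultiLabelClassification, just the same idea in plain Python.
class_map = {
    '1': ['Nor'],
    '2': ['RoT', 'Pre'],
    '3': ['PVC', 'Pre'],
    '4': ['SPC', 'Pre'],
    '5': ['Unk'],
}

# Vocabulary = every distinct label, sorted (assumed ordering).
vocab = sorted({lbl for lbls in class_map.values() for lbl in lbls})

def multi_hot(labels, vocab):
    """Return a 0/1 vector with a 1 in each column whose label is present."""
    return [1.0 if v in labels else 0.0 for v in vocab]

targets = {k: multi_hot(v, vocab) for k, v in class_map.items()}

# Size arithmetic from the error message (assumes a training batch of 64).
bs = 64
n_labels = len(vocab)
target_numel = bs * n_labels  # flattened target size
input_numel = 2304            # flattened model-output size from the error

print(vocab)              # ['Nor', 'PVC', 'Pre', 'RoT', 'SPC', 'Unk']
print(targets['2'])       # [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
print(target_numel)       # 384, matching the reported target size
print(input_numel // bs)  # 36 outputs per sample instead of 6
```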

My computer_setup() output:

os : Linux-5.17.11-200.fc35.x86_64-x86_64-with-glibc2.34
python : 3.8.12
tsai : 0.3.2
fastai : 2.5.6
fastcore : 1.3.27
torch : 1.10.2+cu102
device : cpu
cpu cores : 12
RAM : 15.42 GB
GPU memory : N/A

gitusrnm commented 2 years ago

Same issue: #534

gitusrnm commented 2 years ago

Probably the same issue: #420

oguiza commented 1 year ago

Hi @gitusrnm, I'm sorry for the late reply. I've fixed an issue that impacted MultiLabel tasks. I've run this code and it works well:

from tsai.all import * 

dsid = 'ECG5000' 
X, y, splits = get_UCR_data(dsid, split_data=False)

class_map = {
    '1':['Nor'],          # N:1  - Normal
    '2':['RoT', 'Pre'],   # r:2  - R-on-T premature ventricular contraction
    '3':['PVC', 'Pre'],   # V:3  - Premature ventricular contraction
    '4':['SPC', 'Pre'],   # S:4  - Supraventricular premature or ectopic beat (atrial or nodal)
    '5':['Unk'],          # Q:5  - Unclassifiable beat
}
labeler = ReLabeler(class_map)
y_multi = labeler(y)

tfms = [None, TSMultiLabelClassification()]  # TSMultiLabelClassification() == [MultiCategorize(), OneHotEncode()]
batch_tfms = [TSStandardize()]
dls = get_ts_dls(X, y_multi, splits=splits, tfms=tfms, batch_tfms=batch_tfms, bs=[64, 128])

learn = ts_learner(dls, InceptionTimePlus, loss_func=BCEWithLogitsLossFlat(), cbs=[ShowGraph()])
learn.fit_one_cycle(1, lr_max=1e-3)
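For multi-label inference, fastai's BCEWithLogitsLossFlat applies a sigmoid to each output and decodes with a threshold (`thresh`, 0.5 by default), so each label is kept or dropped independently instead of picking a single class with a softmax. The plain-Python sketch below mimics only that decoding step; the logits are made up, the label order is assumed, and it is not what tsai's get_X_preds runs internally.

```python
import math

# Assumed label order (a sorted set of the class_map labels) and
# hypothetical raw model outputs (logits) for one sample, one per label.
vocab = ['Nor', 'PVC', 'Pre', 'RoT', 'SPC', 'Unk']
logits = [-3.0, 2.5, 1.2, -0.4, -2.0, -5.0]

def sigmoid(x):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def decode(logits, vocab, thresh=0.5):
    """Independent sigmoid per label, then keep every label above thresh."""
    probs = [sigmoid(z) for z in logits]
    return [lbl for lbl, p in zip(vocab, probs) if p > thresh]

print(decode(logits, vocab))  # ['PVC', 'Pre']
```

Lowering `thresh` trades precision for recall: more labels per sample pass the cut-off.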

gitusrnm commented 1 year ago

I checked it out, and it works for me now.

Many thanks @oguiza for creating and supporting such a great library!!!