YGZWQZD / LAMDA-SSL

30 Semi-Supervised Learning Algorithms
MIT License

How to run my three-category tabular data #4

Open smntjugithub opened 1 year ago

smntjugithub commented 1 year ago

Thanks for the great work, I need your help.

If I want to solve a three-class problem, which code should I modify? For example, suppose the BreastCancer dataset had three classes. I ask because when I ran it without modifying any code, the confusion matrix only contained predictions for the first two classes.

    Result/Co_Training_BreastCancer.txt:
    accuracy 0.324468085106383
    precision 0.2598727091480715
    Recall 0.3464646464646464
    F1 0.2306878306878307
    Confusion_matrix
    [[0.16666667 0.83333333 0.        ]
     [0.12727273 0.87272727 0.        ]
     [0.12727273 0.87272727 0.        ]]
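
The reported metrics are consistent with macro averaging over a confusion matrix of counts. The row-normalized matrix implies (our inference, not stated in the thread) a test set of 78, 55 and 55 samples per class: 13/78 ≈ 0.1667 and 7/55 ≈ 0.1273, and 61/188 matches the accuracy. The sketch below recomputes the numbers from those inferred counts, assuming a class that is never predicted scores 0 precision and 0 F1.

```python
import numpy as np

def macro_metrics(cm):
    """Accuracy and macro precision/recall/F1 from a confusion matrix of counts.

    Rows are true classes, columns are predicted classes. A class that is never
    predicted (all-zero column) gets precision 0, and F1 0 when precision + recall is 0.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    pred_totals = cm.sum(axis=0)  # how often each class was predicted
    true_totals = cm.sum(axis=1)  # how many samples each class has
    precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {
        "accuracy": tp.sum() / cm.sum(),
        "precision": precision.mean(),
        "recall": recall.mean(),
        "f1": f1.mean(),
    }

# Counts inferred from the first Co_Training result; class 2 is never predicted.
cm_counts = [[13, 65, 0],
             [7, 48, 0],
             [7, 48, 0]]
print(macro_metrics(cm_counts))
```

Because the third column is all zeros, class 2 contributes 0 to every macro average, which pulls precision, recall and F1 down regardless of how well the other two classes are separated.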

YGZWQZD commented 1 year ago

Hi. This happens because the original co-training algorithm doesn't support multiclass classification. We have just extended it to the multiclass case, so you can clone the new version of the code. Thank you for your question, and feel free to contact us if you have any other questions.

smntjugithub commented 1 year ago

Thank you very much for your prompt help. I haven't tried it yet. Do all the classification algorithms support multiclass classification, or only the co-training algorithm?

smntjugithub commented 1 year ago

I re-tried with the latest package, but it seems that only the first class is predicted.

    Result/Co_Training_BreastCancer.txt:
    accuracy 0.4148936170212766
    precision 0.13829787234042554
    Recall 0.3333333333333333
    F1 0.1954887218045113
    AUC 0.6487264250900614
    Confusion_matrix
    [[1. 0. 0.]
     [1. 0. 0.]
     [1. 0. 0.]]
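
These numbers are exactly what a classifier that always predicts class 0 would score. If p is the share of class-0 samples in the test set, accuracy is p, macro precision is p/3, macro recall is 1/3 and macro F1 is (2p/(1+p))/3. A small check, assuming the same 188-sample test set with 78 class-0 samples inferred from the first result's confusion matrix:

```python
from fractions import Fraction

def constant_predictor_metrics(n_predicted_class, n_total, n_classes):
    """Macro metrics for a classifier that always predicts one class.

    Only the predicted class gets non-zero precision/recall; every other
    class contributes 0 to the macro averages.
    """
    p = Fraction(n_predicted_class, n_total)  # share of the always-predicted class
    accuracy = p
    precision = p / n_classes                 # that class's precision is p
    recall = Fraction(1, n_classes)           # that class's recall is 1
    f1 = (2 * p / (1 + p)) / n_classes        # that class's F1 with recall 1
    return accuracy, precision, recall, f1

for value in constant_predictor_metrics(78, 188, 3):
    print(float(value))
```

Matching these values is a quick way to confirm the model has collapsed to one class rather than learned something weak.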

smntjugithub commented 1 year ago

Here is my raw data: breast_cancer.csv. It contains 625 samples with 343 features, belonging to 3 classes. I believe the data format is correct.
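
Before suspecting the algorithm, it can help to check the class counts and to split the data so that every class appears among the labeled samples. The sketch below is a hypothetical preprocessing helper, not part of LAMDA-SSL: it assumes the label sits in the last column, remaps labels to 0..K-1, and makes a per-class split into labeled, unlabeled and test sets. Random numbers stand in for the CSV, and the split sizes are arbitrary.

```python
import numpy as np

def stratified_ssl_split(X, y, labeled_per_class, test_frac, seed=0):
    """Hypothetical helper: per class, send test_frac of the rows to the test set,
    keep labeled_per_class rows labeled, and put the rest in the unlabeled set."""
    rng = np.random.default_rng(seed)
    lab, unlab, test = [], [], []
    for c in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == c))
        n_test = int(round(test_frac * len(idx)))
        test.extend(idx[:n_test])
        lab.extend(idx[n_test:n_test + labeled_per_class])
        unlab.extend(idx[n_test + labeled_per_class:])
    lab, unlab, test = (np.array(i) for i in (lab, unlab, test))
    return (X[lab], y[lab]), X[unlab], (X[test], y[test])

# Stand-in for breast_cancer.csv: 625 rows, 343 features, label in the last column.
rng = np.random.default_rng(0)
data = np.hstack([rng.normal(size=(625, 343)), rng.integers(1, 4, size=(625, 1))])
X = data[:, :-1].astype(float)
classes, y = np.unique(data[:, -1], return_inverse=True)  # remap labels to 0..K-1
print("classes:", classes, "counts:", np.bincount(y))

(labeled_X, labeled_y), unlabeled_X, (test_X, test_y) = stratified_ssl_split(
    X, y, labeled_per_class=10, test_frac=0.3)
print(labeled_X.shape, unlabeled_X.shape, test_X.shape, np.bincount(labeled_y))
```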

YGZWQZD commented 1 year ago

Hi. You can try the latest version (https://github.com/YGZWQZD/LAMDA-SSL/tree/master/LAMDA_SSL/Algorithm/Classification/Co_Training.py), which hasn't been released yet; it supports multiclass classification. Note that TSVM, LapSVM and SemiBoost support only binary classification.

smntjugithub commented 1 year ago

Thanks for your help. I re-tried the latest version (https://github.com/YGZWQZD/LAMDASSL/tree/master/LAMDA_SSL/Algorithm/Classification/Co_Training.py), but it still seems that only the first class is predicted:

    Result/Co_Training_BreastCancer.txt:
    accuracy 0.4148936170212766
    precision 0.13829787234042554
    Recall 0.3333333333333333
    F1 0.1954887218045113
    AUC 0.5935805608532881
    Confusion_matrix
    [[1. 0. 0.]
     [1. 0. 0.]
     [1. 0. 0.]]

YGZWQZD commented 1 year ago

Could you try raising the parameter 'threshold'? If your code isn't private, you can send it to my email (jialh@lamda.nju.edu.cn).
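
In self-labelling methods, a confidence threshold typically decides which unlabeled samples receive pseudo-labels: only predictions whose top class probability reaches the threshold are added to the training set. Raising it keeps fewer, more confident pseudo-labels, which may stop one dominant class from flooding the training set. The NumPy sketch below illustrates that filtering only; it is not LAMDA-SSL's Co_Training code, and exactly how its `threshold` parameter is applied should be checked in Co_Training.py.

```python
import numpy as np

def select_pseudo_labels(proba, threshold):
    """Return indices and pseudo-labels of rows whose max probability >= threshold."""
    proba = np.asarray(proba)
    confidence = proba.max(axis=1)
    keep = np.flatnonzero(confidence >= threshold)
    return keep, proba[keep].argmax(axis=1)

# Illustrative class probabilities for four unlabeled rows (3 classes).
proba = np.array([[0.50, 0.30, 0.20],
                  [0.90, 0.05, 0.05],
                  [0.10, 0.85, 0.05],
                  [0.34, 0.33, 0.33]])
for t in (0.3, 0.8):
    idx, labels = select_pseudo_labels(proba, t)
    print(t, idx.tolist(), labels.tolist())
```

With threshold 0.3 all four rows are pseudo-labelled, three of them as class 0; with 0.8 only rows 1 and 2 are kept.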

Zebin-Li commented 1 year ago

Hello,

Thank you for the wonderful work! I have a similar question: can I use deep SSL methods on tabular data? For example, can FixMatch be used on the BreastCancer data? Thank you very much!

Swww-w commented 1 year ago

Hello, have you solved this? I would also like to try deep SSL algorithms such as FlexMatch on tabular data. Thanks!

Zebin-Li commented 1 year ago

Hi @Swww-w, yes, I have applied deep SSL to tabular data. Please see the following code for reference. Thank you!

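    # Imports assume the LAMDA-SSL package layout (LAMDA_SSL.*).
    import torch.nn as nn
    from LAMDA_SSL.Algorithm.Classification.FixMatch import FixMatch
    from LAMDA_SSL.Augmentation.Tabular.Noise import Noise
    from LAMDA_SSL.Dataset.LabeledDataset import LabeledDataset
    from LAMDA_SSL.Dataset.UnlabeledDataset import UnlabeledDataset
    from LAMDA_SSL.Network.MLPCLS import MLPCLS
    from LAMDA_SSL.Transform.ToTensor import ToTensor

    # labeled_SSL_X, labeled_SSL_y, unlabeled_SSL_X, testing_set_X and testing_set_y
    # are your own NumPy arrays (labeled, unlabeled and test splits).
    labeled_X = labeled_SSL_X.astype(float)
    labeled_y = labeled_SSL_y
    unlabeled_X = unlabeled_SSL_X.astype(float)
    test_X = testing_set_X.astype(float)
    test_y = testing_set_y

    # Convert each tabular row to a tensor inside the datasets.
    transform = ToTensor()
    labeled_dataset = LabeledDataset(transform=transform)
    unlabeled_dataset = UnlabeledDataset(transform=transform)
    test_dataset = UnlabeledDataset(transform=transform)

    # Tabular augmentation: perturbs the features with small random noise.
    augmentation = Noise(noise_level=0.01)
    # MLP classifier; dim_in must equal the number of features. If MLPCLS exposes
    # an output-size argument, set it to your number of classes.
    network = MLPCLS(hidden_dim=[64, 32, 8], activations=[nn.ReLU(), nn.ReLU(), nn.ReLU()], dim_in=labeled_X.shape[-1])

    # FixMatch
    model_FixMatch = FixMatch(
        labeled_dataset=labeled_dataset, unlabeled_dataset=unlabeled_dataset, test_dataset=test_dataset,
        augmentation=augmentation, threshold=0.95, lambda_u=1.0, network=network, mu=7, T=0.4, epoch=1,
        num_it_epoch=400, num_it_total=400)

    model_FixMatch.fit(X=labeled_X, y=labeled_y, unlabeled_X=unlabeled_X)
    pred_y = model_FixMatch.predict(X=test_X)  # compare against test_y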
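
For readers wondering what `threshold`, `lambda_u` and `mu` control here: FixMatch pseudo-labels each unlabeled sample from a weakly augmented view, keeps only pseudo-labels whose confidence reaches `threshold`, and trains the strongly augmented view to match them, weighting that term by `lambda_u`; `mu` sets how many unlabeled samples are drawn per labeled one. The NumPy sketch below computes that unlabeled loss for tabular rows. It assumes Gaussian noise for both views (small noise as the weak view, larger noise as the strong view) and a plain linear softmax model in place of MLPCLS; it illustrates the idea and is not LAMDA-SSL's implementation.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def fixmatch_unlabeled_loss(W, b, X_u, threshold=0.95, lambda_u=1.0,
                            weak_noise=0.01, strong_noise=0.1, rng=None):
    """FixMatch-style unlabeled loss for a linear softmax model (illustrative).

    Weak view: X_u plus small Gaussian noise -> pseudo-labels and confidences.
    Strong view: X_u plus larger Gaussian noise -> cross-entropy against the
    pseudo-labels, counted only where confidence >= threshold.
    """
    rng = np.random.default_rng() if rng is None else rng
    X_weak = X_u + weak_noise * rng.normal(size=X_u.shape)
    X_strong = X_u + strong_noise * rng.normal(size=X_u.shape)
    p_weak = softmax(X_weak @ W + b)
    pseudo = p_weak.argmax(axis=1)
    mask = p_weak.max(axis=1) >= threshold  # only confident pseudo-labels count
    p_strong = softmax(X_strong @ W + b)
    ce = -np.log(p_strong[np.arange(len(X_u)), pseudo] + 1e-12)
    # Averaged over the whole unlabeled batch, masked entries contributing 0.
    return lambda_u * np.mean(ce * mask), mask.mean()

rng = np.random.default_rng(0)
X_u = rng.normal(size=(56, 5))          # mu=7 times an assumed labeled batch of 8
W = rng.normal(scale=3.0, size=(5, 3))  # stand-in weights for 3 classes
b = np.zeros(3)
loss, used = fixmatch_unlabeled_loss(W, b, X_u, rng=rng)
print(f"unlabeled loss {loss:.4f}, fraction above threshold {used:.2f}")
```

A high `threshold` such as 0.95 means that early in training few unlabeled rows pass the mask, so the labeled loss dominates until the model becomes confident.
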
Swww-w commented 1 year ago

Thank you very, very much!

Swww-w commented 1 year ago

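Hi, I'm very sorry to message you so late. Running this FixMatch code on my own tabular data (near-infrared spectroscopy data) keeps throwing errors. I've spent a whole day trying to fix it without success, and there's nobody else who can help me, so I'd like to ask whether you could send me your data and code. I promise this is the last time I'll message you; I won't bother you again.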

Zebin-Li commented 1 year ago

Hi, I replied to you via email. Have you received it?

Swww-w commented 1 year ago

Hi, I don't seem to have received it. Here is my QQ email address: 1927180004@qq.com. Could you please send it again? Thanks!

Zebin-Li commented 1 year ago

sent