snehaputul / AllMatch

Official implementation of our ECCV'24 paper, "Improving 3D Semi-supervised Learning by Effectively Utilizing All Unlabelled Data".
MIT License

Improving 3D Semi-supervised Learning by Effectively Utilizing All Unlabelled Data

This is a PyTorch implementation of AllMatch.

Sneha Paul*, Zachary Patterson, and Nizar Bouguila

CIISE, Concordia University, Montreal, Canada

(*) corresponding author

Abstract

Semi-supervised learning (SSL) has proven effective at learning 3D representations from a small amount of labelled data together with a large amount of unlabelled data. Traditional semi-supervised approaches rely on predicting pseudo-labels for unlabelled data and incorporating them into the learning process. However, we find that existing methods do not fully utilize all the unlabelled samples, which limits their potential performance. To address this issue, we propose AllMatch, a novel SSL-based 3D classification framework that effectively utilizes all the unlabelled samples. AllMatch comprises three modules: (1) an adaptive hard augmentation module that applies relatively hard augmentations to high-confidence unlabelled samples with lower loss values, thereby enhancing the contribution of such samples; (2) an inverse learning module that further improves the utilization of unlabelled data by learning what not to learn; and (3) a contrastive learning module that ensures learning from all the samples in both supervised and unsupervised settings. Comprehensive experiments on two popular 3D datasets demonstrate a performance improvement of up to 11.2% with 1% labelled data, surpassing the SOTA by a significant margin. Furthermore, AllMatch leverages unlabelled data efficiently: with only 10% of the data labelled, it reaches nearly the same performance as fully supervised learning with all labelled data.
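To make the pseudo-labelling and "learning what not to learn" ideas more concrete, here is a minimal, framework-free sketch. It is not the AllMatch implementation: it assumes a generic FixMatch-style confidence threshold and a generic negative-learning loss, and the function names and the `threshold` and `reject_below` values are hypothetical, chosen only for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pseudo_label_loss(weak_logits, strong_logits, threshold=0.95):
    """Cross-entropy on the strongly augmented view against the pseudo-label
    from the weakly augmented view. Returns 0.0 when the weak-view confidence
    is below the threshold (assumed FixMatch-style rule)."""
    probs = softmax(weak_logits)
    confidence = max(probs)
    if confidence < threshold:
        return 0.0  # the sample contributes nothing to this loss term
    label = probs.index(confidence)
    return -math.log(softmax(strong_logits)[label])

def inverse_loss(weak_logits, strong_logits, reject_below=0.05):
    """Negative-learning sketch: classes the weak view considers very unlikely
    are treated as 'not this class', and the strong view is penalised for
    assigning them probability."""
    weak = softmax(weak_logits)
    strong = softmax(strong_logits)
    rejected = [k for k, p in enumerate(weak) if p < reject_below]
    if not rejected:
        return 0.0
    # The small constant guards against log(0) when strong[k] is close to 1.
    return -sum(math.log(1.0 - strong[k] + 1e-12) for k in rejected) / len(rejected)
```

In this sketch, a sample that fails the confidence threshold yields no pseudo-label loss but can still contribute through `inverse_loss`, which is one way unlabelled samples that would otherwise be discarded can still provide a training signal.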

Overview

Requirements

python >= 3.7

pytorch >= 1.6

h5py

scikit-learn

Datasets

ModelNet40 (https://modelnet.cs.princeton.edu/)

ScanObjectNN (https://hkust-vgd.github.io/scanobjectnn/)

Example training and testing

# train
python main.py --exp_name train --perceptange 2 --num_points 1024 --dataset ModelNet40 --batch_size 24 --ema_m 0.99 --unlabeled_ratio 5 --epochs 350 --masking_epoch 50 --lr 7.5e-05 --fake_epoch 5 --u_lambda 1.0 --lambda_ce 1 --unsupcon_lambda 0.2 --supcon_lambda 0 --nl_lambda 1
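The `--lambda_ce`, `--u_lambda`, `--nl_lambda`, `--unsupcon_lambda`, `--supcon_lambda`, and `--ema_m` flags suggest a weighted sum of loss terms and an exponential-moving-average (EMA) model. The sketch below shows how such values are commonly combined. The mapping from flags to loss terms is an assumption based on the flag names alone, not taken from `main.py`, and `total_loss` and `ema_update` are hypothetical helpers.

```python
# Weights copied from the example command above; their meaning is assumed.
WEIGHTS = {
    "lambda_ce": 1.0,        # supervised cross-entropy (assumed)
    "u_lambda": 1.0,         # unsupervised pseudo-label loss (assumed)
    "nl_lambda": 1.0,        # inverse / negative-learning loss (assumed)
    "unsupcon_lambda": 0.2,  # unsupervised contrastive loss (assumed)
    "supcon_lambda": 0.0,    # supervised contrastive loss (assumed)
}

def total_loss(ce, unsup, nl, unsupcon, supcon, w=WEIGHTS):
    """Weighted sum of the individual loss terms (hypothetical combination)."""
    return (w["lambda_ce"] * ce
            + w["u_lambda"] * unsup
            + w["nl_lambda"] * nl
            + w["unsupcon_lambda"] * unsupcon
            + w["supcon_lambda"] * supcon)

def ema_update(ema_params, params, m=0.99):
    """Standard EMA update: keep a fraction m of the old averaged value and
    blend in (1 - m) of the current value."""
    return [m * e + (1.0 - m) * p for e, p in zip(ema_params, params)]
```

Under this reading, `--supcon_lambda 0` would switch off the supervised contrastive term for this run, and `--ema_m 0.99` would make the averaged weights change slowly relative to the trained weights.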

Acknowledgements

This project is based on FixMatch (paper), FlexMatch (paper), FullMatch (paper), and ConFid (paper). We thank the authors for their excellent work.