psellcam / LaplaceNet

A PyTorch implementation of LaplaceNet: A Hybrid Energy-Neural Model for Deep Semi-Supervised Classification

LaplaceNet

This repository contains the code for the paper https://arxiv.org/abs/2106.04527

LaplaceNet: A Hybrid Energy-Neural Model for Deep Semi-Supervised Classification

Please cite our work if you use this code in your paper.

This code builds on prior work by https://github.com/CuriousAI/mean-teacher/tree/master/pytorch and https://github.com/ahmetius/LP-DeepSSL, and we are deeply grateful to these researchers.

Using this repository

This repository contains everything you need to recreate the experiments from our paper and use our code. After downloading and extracting the repository, you need to extract the data files and set up a suitable environment; then you can run the code. A guide to each step follows below.

Data Extraction

Run these commands, starting from the base path where you installed the repository, to extract the data for CIFAR-10/100.

CIFAR-10

>> cd data-local/bin
>> ./prepare_cifar10.sh

CIFAR-100

>> cd data-local/bin
>> ./prepare_cifar100.sh

Mini-Imagenet

We follow ahmetius's approach: download the train and test tars from http://ptak.felk.cvut.cz/personal/toliageo/share/lpdeep/ and extract them into the following directory:

>> ./data-local/images/miniimagenet/
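If you prefer to do the extraction from Python, the sketch below uses the standard-library `tarfile` module to unpack each downloaded archive into that directory. It is not part of this repository; `extract_tars` is a hypothetical helper, and the archive paths you pass are placeholders for wherever you saved the tars.

```python
import tarfile
from pathlib import Path

# Target directory given in this README.
TARGET = Path("./data-local/images/miniimagenet")


def extract_tars(tar_paths, target=TARGET):
    """Extract each downloaded tar archive into `target` and return it."""
    target = Path(target)
    target.mkdir(parents=True, exist_ok=True)
    # Use the safer "data" extraction filter when this Python version has it.
    kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    for tar_path in tar_paths:
        # Mode "r" transparently handles gzip/bz2/xz compression.
        with tarfile.open(tar_path, "r") as tar:
            tar.extractall(path=target, **kwargs)
    return target
```

For example, `extract_tars(["path/to/train.tar", "path/to/test.tar"])`, with the paths replaced by your own download locations.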

Setting Up Environment

Requirements

From a clean conda environment, you can run the following commands to set up a suitable environment:

Note that faiss-gpu has compatibility issues with certain versions of PyTorch, but the combination above is confirmed to work.
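Once the environment is set up, a quick sanity check such as the following sketch (not part of this repository) can confirm that PyTorch and faiss import side by side and can see a GPU. `check_environment` is a hypothetical helper; it only reports what it finds and does not validate specific versions.

```python
import importlib


def check_environment():
    """Report whether torch and faiss import, their versions, and GPU visibility."""
    report = {}
    for name in ("torch", "faiss"):
        try:
            module = importlib.import_module(name)
        except (ImportError, OSError) as err:
            # Missing package or a broken shared library.
            report[name] = {"ok": False, "error": str(err)}
            continue
        info = {"ok": True, "version": getattr(module, "__version__", "unknown")}
        if name == "torch":
            info["cuda"] = module.cuda.is_available()
        else:
            # Looked up defensively in case the installed build does not expose it.
            get_num_gpus = getattr(module, "get_num_gpus", None)
            info["gpus"] = get_num_gpus() if get_num_gpus else 0
        report[name] = info
    return report


if __name__ == "__main__":
    for name, info in check_environment().items():
        print(name, info)
```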

Running the program

To recreate the CIFAR-10 results from the main paper with 4k labels (for any label split), run:

python main.py --dataset cifar10 --model wrn-28-8 --num-labeled 4000 --alpha 1.0 --lr 0.03 --labeled-batch-size 48 --batch-size 300 --aug-num 3 --label-split 12 --progress True
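The CIFAR-10 command above sets both `--labeled-batch-size 48` and `--batch-size 300`. Assuming these flags keep the meaning they have in the mean-teacher code this repository builds on (total batch size, and the labeled share of it; confirm in config/cli.py), each batch would hold 48 labeled and 252 unlabeled samples. The sketch below illustrates that two-stream batching idea in plain Python; `two_stream_batches` is a hypothetical stand-in, not this repository's sampler.

```python
import random


def two_stream_batches(labeled_idxs, unlabeled_idxs, batch_size,
                       labeled_batch_size, seed=0):
    """Yield index batches, each with a fixed number of labeled samples.

    Every batch holds `batch_size - labeled_batch_size` unlabeled indices
    followed by `labeled_batch_size` labeled ones. One pass over the shuffled
    unlabeled set makes an epoch; the smaller labeled set is reshuffled and
    cycled as often as needed.
    """
    if not 0 < labeled_batch_size < batch_size:
        raise ValueError("need 0 < labeled_batch_size < batch_size")
    labeled = list(labeled_idxs)
    if not labeled:
        raise ValueError("labeled_idxs must not be empty")
    rng = random.Random(seed)
    unlabeled = list(unlabeled_idxs)
    rng.shuffle(unlabeled)
    unlabeled_per_batch = batch_size - labeled_batch_size

    def cycle_labeled():
        # Endless stream of labeled indices, reshuffled on every pass.
        while True:
            pool = labeled[:]
            rng.shuffle(pool)
            yield from pool

    labeled_stream = cycle_labeled()
    # The trailing incomplete chunk of unlabeled indices is dropped.
    last_start = len(unlabeled) - unlabeled_per_batch
    for start in range(0, last_start + 1, unlabeled_per_batch):
        batch = unlabeled[start:start + unlabeled_per_batch]
        batch += [next(labeled_stream) for _ in range(labeled_batch_size)]
        yield batch
```

With `batch_size=300` and `labeled_batch_size=48`, every yielded batch has 252 unlabeled indices followed by 48 labeled ones. Keeping the labeled count fixed means the supervised loss sees a steady number of labeled samples per step, even when labels are scarce.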

To recreate the CIFAR-100 results from the main paper with 10k labels (for any label split), run:

python main.py --dataset cifar100 --model wrn-28-8 --num-labeled 10000 --alpha 0.5 --lr 0.03 --labeled-batch-size 50 --aug-num 3 --label-split 12 --progress True

To recreate the Mini-Imagenet results from the main paper with 4k labels (for any label split), run:

python main.py --dataset miniimagenet --model resnet18 --num-labeled 4000 --alpha 0.5 --lr 0.1 --labeled-batch-size 50 --aug-num 3 --label-split 12 --progress True
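To repeat one of these runs over several label splits, a small launcher can rebuild the commands from their settings. The sketch below copies the flag values from the three commands above; `build_command` and `run_splits` are hypothetical helpers, and which `--label-split` values are valid depends on the label files the repository provides, so the splits you pass are your own assumption.

```python
import shlex
import subprocess

# Flag values copied from the three commands above (CIFAR-100 and
# Mini-Imagenet leave --batch-size at its default).
CONFIGS = {
    "cifar10": {"model": "wrn-28-8", "num-labeled": 4000, "alpha": 1.0,
                "lr": 0.03, "labeled-batch-size": 48, "batch-size": 300,
                "aug-num": 3},
    "cifar100": {"model": "wrn-28-8", "num-labeled": 10000, "alpha": 0.5,
                 "lr": 0.03, "labeled-batch-size": 50, "aug-num": 3},
    "miniimagenet": {"model": "resnet18", "num-labeled": 4000, "alpha": 0.5,
                     "lr": 0.1, "labeled-batch-size": 50, "aug-num": 3},
}


def build_command(dataset, label_split, progress=True):
    """Return the main.py argument list for one dataset and label split."""
    args = ["python", "main.py", "--dataset", dataset]
    for flag, value in CONFIGS[dataset].items():
        args += [f"--{flag}", str(value)]
    args += ["--label-split", str(label_split), "--progress", str(progress)]
    return args


def run_splits(dataset, splits, dry_run=True):
    """Print (and, unless dry_run, execute) one run per label split."""
    for split in splits:
        cmd = build_command(dataset, split)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_splits("cifar10", [12])  # dry run: only prints the command
```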

Increasing --aug-num should improve accuracy, at the cost of more computation per training step.

Command line arguments

The documentation for the command line arguments can be found in config/cli.py. Here we give some extra information on the most important ones.

Some graph-based parameters are not exposed as CLI arguments. They may be changed directly in the code, but I don't recommend doing so unless you have a good reason. If you want to try another graph-based approach or a different propagator, you would need to rewrite the one_iter_true function in db_semisuper.py and replace it with whatever you like.
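As one example of a propagator you might swap in, the sketch below implements label spreading (Zhou et al., 2004), the kind of graph diffusion used in LP-DeepSSL, with NumPy: it builds a symmetric kNN graph from features and solves (I - αS)Z = Y. This is not LaplaceNet's own propagator, and `knn_affinity` and `label_spreading` are hypothetical helpers whose arguments do not match the interface of one_iter_true; adapting it would mean mapping your features and labels onto whatever that function expects.

```python
import numpy as np


def knn_affinity(features, k=10, gamma=3.0):
    """Symmetric k-nearest-neighbour affinity matrix from cosine similarity."""
    x = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)            # never pick a point as its own neighbour
    n = x.shape[0]
    nbrs = np.argsort(-sim, axis=1)[:, :k]    # k most similar points per row
    rows = np.repeat(np.arange(n), k)
    cols = nbrs.ravel()
    w = np.zeros((n, n))
    # Negative similarities are clipped to 0 before raising to gamma.
    w[rows, cols] = np.clip(sim[rows, cols], 0.0, None) ** gamma
    return w + w.T                            # symmetrise the directed kNN graph


def label_spreading(w, labels, num_classes, alpha=0.99):
    """Label spreading (Zhou et al., 2004): solve (I - alpha*S) Z = Y.

    S = D^-1/2 W D^-1/2 is the symmetrically normalised affinity matrix and
    Y one-hot encodes the known labels (unlabelled points are marked -1 in
    `labels` and get all-zero rows). Returns hard pseudo-labels and
    row-normalised scores.
    """
    n = w.shape[0]
    degree = w.sum(axis=1)
    d_inv_sqrt = np.zeros_like(degree)
    nonzero = degree > 0
    d_inv_sqrt[nonzero] = degree[nonzero] ** -0.5
    s = d_inv_sqrt[:, None] * w * d_inv_sqrt[None, :]
    y = np.zeros((n, num_classes))
    known = labels >= 0
    y[np.flatnonzero(known), labels[known]] = 1.0
    # Dense solve for clarity; large graphs call for sparse matrices and an
    # iterative solver such as conjugate gradient.
    z = np.linalg.solve(np.eye(n) - alpha * s, y)
    z = np.clip(z, 0.0, None)
    scores = z / np.maximum(z.sum(axis=1, keepdims=True), 1e-12)
    return scores.argmax(axis=1), scores
```

On two well-separated clusters with one labeled point each, every unlabeled point inherits the label of its own cluster. Because 0 < α < 1 and S has spectral radius at most 1, the matrix I - αS is invertible, so the solve always has a unique answer.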

Maintenance

I will try my best to keep this repository up to date. If you find a bug or want to make a comment, please feel free to do so, and I will try my best to resolve your problem quickly. Additionally, if time during my PhD allows, I aim to extend this repository with features such as distributed training.