stsievert / LeanSGD

Wide Residual Networks (WideResNets) in PyTorch

WideResNets for CIFAR10/100 implemented in PyTorch. This implementation uses less GPU memory than the official Torch implementation: https://github.com/szagoruyko/wide-residual-networks.

Example:

python train.py --dataset cifar100 --layers 40 --widen-factor 4

How to generate results

The script run.py runs train.py and writes summary CSVs into output/{today. It runs the following commands:

python train.py --qsgd=1  # use QSGD coding
python train.py --compress=1 --svd_rank=0 --svd_rescale=1  # use SVD coding
python train.py --compress=0  # use normal SGD with the param server

Note that extra arguments are added to each of these commands.
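As an illustration of the idea behind the `--qsgd=1` flag (a minimal sketch of QSGD-style stochastic quantization, not the repo's actual coding code), each gradient coordinate is randomly rounded to one of `s` levels so that the quantized gradient is unbiased:

```python
import numpy as np

# Illustrative sketch of QSGD-style stochastic quantization -- not the
# repo's implementation. Each coordinate |v_i|/||v|| is rounded to one of
# s levels, rounding up with probability equal to the fractional part, so
# E[qsgd_quantize(v)] == v (the estimator is unbiased).

def qsgd_quantize(v, s, rng):
    norm = np.linalg.norm(v)
    if norm == 0:
        return np.zeros_like(v)
    level = np.abs(v) / norm * s      # scaled magnitude in [0, s]
    low = np.floor(level)
    # round up with probability equal to the fractional part
    up = rng.random(v.shape) < (level - low)
    q = (low + up) / s
    return norm * np.sign(v) * q

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    v = rng.standard_normal(8)
    # averaging many quantizations recovers v (unbiasedness)
    est = np.mean([qsgd_quantize(v, s=4, rng=rng) for _ in range(20000)], axis=0)
    print(np.max(np.abs(est - v)))
```

With a small `s`, each coordinate needs only a few bits plus a shared norm, which is what makes the coding cheap to communicate.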

File structure

pytorch_ps_mpi/
    ps.py
    mpi_comms.py
codings/
    coding.py
    ...
train.py
run.py
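The `codings/` directory holds the gradient codings. As a hedged sketch of the low-rank idea behind `--compress=1 --svd_rank=r` (an illustration, not the contents of `codings/coding.py`), a 2-D gradient can be truncated to its top singular triplets before being sent:

```python
import numpy as np

# Illustrative low-rank gradient compression via truncated SVD -- a sketch
# of the idea behind SVD coding, not the repo's codings/coding.py.

def svd_encode(grad, rank):
    """Keep only the top-`rank` singular triplets of a 2-D gradient."""
    u, s, vt = np.linalg.svd(grad, full_matrices=False)
    return u[:, :rank], s[:rank], vt[:rank]

def svd_decode(u, s, vt):
    """Reconstruct the (approximate) gradient from its truncated factors."""
    return (u * s) @ vt

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    g = rng.standard_normal((64, 32))
    u, s, vt = svd_encode(g, rank=4)
    # floats communicated: rank*(64 + 1 + 32) instead of 64*32
    print(u.size + s.size + vt.size, "vs", g.size)
```

Transmitting the factors instead of the dense gradient trades reconstruction error for a large reduction in bytes on the wire.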

Distributed training

mpirun -n 3 -hostfile hosts --map-by ppr:1:node python train.py
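What the workers compute under this command can be sketched without MPI: each worker takes the gradient on its shard of the batch, and the aggregated step uses their mean, which (for equal-size shards) equals the full-batch gradient. A toy simulation, assuming a simple linear least-squares model for illustration:

```python
import numpy as np

# Toy, MPI-free simulation of data-parallel gradient averaging: three
# "workers" each compute a gradient on their shard; the mean of the shard
# gradients equals the full-batch gradient when shards are equal-sized.
# The linear least-squares model here is only for illustration.

def worker_grad(w, x_shard, y_shard):
    # gradient of 0.5 * ||x @ w - y||^2 / n_local for a linear model
    r = x_shard @ w - y_shard
    return x_shard.T @ r / len(y_shard)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x, y = rng.standard_normal((12, 4)), rng.standard_normal(12)
    w = np.zeros(4)
    shards = np.split(np.arange(12), 3)   # three equal "workers"
    grads = [worker_grad(w, x[i], y[i]) for i in shards]
    avg = np.mean(grads, axis=0)
    full = worker_grad(w, x, y)
    print(np.allclose(avg, full))
```

In the real run, `mpirun` launches one such worker per node and the parameter server (or an allreduce) performs the averaging step.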

A quick speed test with two p2.xlarge instances and 34 layers:

And with 100 layers:

How will this change as the number of workers increases?

Acknowledgement