microsoft / only_train_once

OTOv1-v3, NeurIPS, ICLR, TMLR, DNN Training, Compression, Structured Pruning, Erasing Operators, CNN, Diffusion, LLM
MIT License

Only Train Once (OTO): Automatic One-Shot DNN Training And Compression Framework

Note: This repository is being migrated here from tianyic/only_train_once.



This repository is the official PyTorch implementation of Only-Train-Once (OTO). OTO is an $\color{LimeGreen}{\textbf{automatic}}$, $\color{LightCoral}{\textbf{architecture}}$ $\color{LightCoral}{\textbf{agnostic}}$ DNN $\color{Orange}{\textbf{training}}$ and $\color{Violet}{\textbf{compression}}$ framework, which compresses models via $\color{CornflowerBlue}{\textbf{structured pruning}}$ and $\color{DarkGoldenRod}{\textbf{erasing}}$ operators. With OTO, users can train a general DNN, either from scratch or from a pretrained checkpoint, to achieve high performance and a slimmer architecture simultaneously in a one-shot manner (without fine-tuning).
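To make "a slimmer architecture without fine-tuning" concrete, here is a toy sketch of structured pruning in NumPy. It is our own illustration, not OTO's implementation, and the function names are hypothetical. If training drives an entire group of parameters (here, one hidden unit's row of W1 plus its bias) exactly to zero, and the activation maps 0 to 0 as ReLU does, that unit always outputs zero. Removing it together with the matching column of W2 therefore leaves the network's output unchanged.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def mlp(x, w1, b1, w2, b2):
    """Two-layer MLP: x -> relu(W1 x + b1) -> W2 h + b2."""
    return w2 @ relu(w1 @ x + b1) + b2


def prune_zero_hidden_units(w1, b1, w2):
    """Drop hidden units whose whole group (row of W1 and entry of b1) is zero.

    Such a unit always outputs relu(0) = 0, so the matching column of W2
    contributes nothing and can be removed without changing the output.
    """
    # Group norm per hidden unit: ||[W1[i, :], b1[i]]||_2
    group_norms = np.sqrt((w1 ** 2).sum(axis=1) + b1 ** 2)
    keep = group_norms > 0.0
    return w1[keep], b1[keep], w2[:, keep]


rng = np.random.default_rng(0)
w1 = rng.standard_normal((8, 4))
b1 = rng.standard_normal(8)
w2 = rng.standard_normal((3, 8))
b2 = rng.standard_normal(3)

# Pretend training drove hidden units 1, 4, and 6 exactly to zero (group sparsity).
for i in (1, 4, 6):
    w1[i] = 0.0
    b1[i] = 0.0

small_w1, small_b1, small_w2 = prune_zero_hidden_units(w1, b1, w2)
x = rng.standard_normal(4)
full_out = mlp(x, w1, b1, w2, b2)
pruned_out = mlp(x, small_w1, small_b1, small_w2, b2)
print(small_w1.shape)  # (5, 4)
print(np.allclose(full_out, pruned_out))  # True
```

The same reasoning fails for activations with f(0) != 0 (e.g. sigmoid), which is why the choice of parameter groups matters.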

Publications

Please find our series of works and their BibTeX entries for citation.


In addition, we recommend the following efficient ML works of ours.

Thank you for the interest and support from our community.

Installation

We recommend running the framework with PyTorch >= 2.0. Install it via pip or by cloning the repository with git.

pip install only_train_once

or

git clone https://github.com/tianyic/only_train_once.git
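Before installing, you may want to confirm that your environment meets the PyTorch >= 2.0 recommendation. Below is a minimal sketch using only the standard library; the helper names are hypothetical and not part of OTO.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_major_minor(ver):
    """Extract (major, minor) from strings like '2.1.0+cu118' or '2.0rc1'."""
    match = re.match(r"(\d+)(?:\.(\d+))?", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    # A missing minor component (e.g. '2') is treated as 0.
    return int(match.group(1)), int(match.group(2) or 0)


def torch_meets_recommendation(min_version=(2, 0)):
    """Return True if the installed torch package is at least min_version."""
    try:
        installed = version("torch")
    except PackageNotFoundError:
        return False
    return parse_major_minor(installed) >= min_version


print(torch_meets_recommendation())
```

Parsing with a regular expression rather than splitting on dots keeps local build tags (`+cu118`) and pre-release suffixes (`rc1`) from breaking the comparison.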

Quick Start

We provide an example of OTO framework usage below. More detailed explanations can be found in the tutorials.

Minimal usage example.

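import torch
import torchvision
import torchvision.transforms as transforms
from sanity_check.backends import densenet121
from only_train_once import OTO

# Create OTO instance
model = densenet121()
dummy_input = torch.zeros(1, 3, 32, 32)
oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())

# Create HESSO optimizer
optimizer = oto.hesso(variant='sgd', lr=0.1, target_group_sparsity=0.7)

# Placeholder data and schedule (assumption): CIFAR-10 matches the 32x32 dummy input.
# Replace with your own dataset, transforms, and epoch budget.
max_epoch = 10
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Train the DNN as normal via HESSO (the model is already on the GPU)
model.train()
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(max_epoch):
    for X, y in trainloader:
        X, y = X.cuda(), y.cuda()
        y_pred = model(X)
        f = criterion(y_pred, y)
        optimizer.zero_grad()
        f.backward()
        optimizer.step()

# Construct the compressed DenseNet and write it to out_dir.
oto.construct_subnet(out_dir='./')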
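One reading of `target_group_sparsity=0.7` is that roughly 70% of the prunable parameter groups should end up exactly zero. The sketch below illustrates only that bookkeeping with a simplified magnitude criterion (smallest group norms first). This criterion is an assumption for illustration: HESSO steers groups toward zero progressively during training rather than hard-thresholding once, and the helper names here are hypothetical.

```python
import numpy as np


def select_groups_to_prune(group_norms, target_group_sparsity):
    """Pick indices of the groups to drive to zero.

    Simplified magnitude criterion (assumption): the
    round(target_group_sparsity * num_groups) groups with the
    smallest norms are selected.
    """
    group_norms = np.asarray(group_norms, dtype=float)
    num_prune = int(round(target_group_sparsity * group_norms.size))
    # argsort is ascending, so the first num_prune indices have the smallest norms.
    return np.sort(np.argsort(group_norms)[:num_prune])


def group_sparsity(weight):
    """Fraction of rows (groups) of a weight matrix that are exactly zero."""
    return float(np.mean(~weight.any(axis=1)))


norms = [0.9, 0.05, 1.3, 0.2, 0.7, 0.01, 0.4, 1.1, 0.3, 0.6]
print(select_groups_to_prune(norms, 0.7))  # [1 3 4 5 6 8 9]

rng = np.random.default_rng(0)
weight = rng.standard_normal((10, 6))  # 10 groups (e.g. output channels), 6 params each
idx = select_groups_to_prune(np.linalg.norm(weight, axis=1), 0.7)
weight[idx] = 0.0  # hard projection: zero out the selected groups
print(group_sparsity(weight))  # 0.7
```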

How the pruning mode in OTO works.

Sanity Check

The sanity check provides tests of OTO's pruning mode on various DNNs, ranging from CNNs to LLMs. Passing the sanity check indicates that OTO is compatible with the target DNN.

python sanity_check/sanity_check.py

Note that some tests require additional dependencies; comment out the tests you do not need. We highly recommend running a sanity check on any new customized DNN to test its compatibility with OTO.
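As an alternative to commenting out tests by hand, tests whose optional dependencies are missing could be skipped automatically. Below is a minimal sketch, assuming each test is a plain callable paired with the top-level module names it needs. The registry format and test names are hypothetical and do not reflect how `sanity_check.py` is actually organized.

```python
import importlib.util


def is_importable(name):
    """True if module `name` can be found without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):  # e.g. a missing parent package
        return False


def run_sanity_checks(tests):
    """Run each test, skipping it if any of its required modules is absent.

    `tests` maps a test name to (required_modules, test_fn).
    """
    results = {}
    for name, (required, test_fn) in tests.items():
        missing = [mod for mod in required if not is_importable(mod)]
        if missing:
            results[name] = "skipped (missing: " + ", ".join(missing) + ")"
            continue
        try:
            test_fn()
            results[name] = "passed"
        except Exception as exc:  # report the failure instead of aborting the run
            results[name] = f"failed: {exc}"
    return results


def always_passes():
    assert 1 + 1 == 2


def always_fails():
    raise RuntimeError("output mismatch")


tests = {
    "toy_ok": (["math"], always_passes),
    "toy_needs_missing_pkg": (["surely_not_installed_pkg_xyz"], always_passes),
    "toy_broken": ([], always_fails),
}
for name, status in run_sanity_checks(tests).items():
    print(f"{name}: {status}")
```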

Visualization

The visual_examples directory provides visualizations of pruning dependency graphs and erasing dependency graphs. Visualization is a frequently used debugging tool when applying OTO to a new, unseen DNN runs into errors.
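To illustrate what a pruning dependency graph captures, here is a toy sketch. It is not OTO's implementation, and the layer names and coupling pairs are hypothetical. Layers whose output channels are coupled, for example because their outputs are summed by a residual addition, must be pruned as one group. Union-find merges such layers, and the result is emitted as Graphviz DOT text with one cluster per group.

```python
def find(parent, x):
    """Return the root of x, halving the path as we go."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def union(parent, a, b):
    ra, rb = find(parent, a), find(parent, b)
    if ra != rb:
        parent[rb] = ra


def pruning_groups(layers, coupled_pairs):
    """Group layers whose output channels must be pruned together."""
    parent = {layer: layer for layer in layers}
    for a, b in coupled_pairs:
        union(parent, a, b)
    groups = {}
    for layer in layers:
        groups.setdefault(find(parent, layer), []).append(layer)
    return list(groups.values())


def to_dot(edges, groups):
    """Emit Graphviz DOT text with one cluster per pruning group."""
    lines = ["digraph pruning_dependency {"]
    for i, group in enumerate(groups):
        lines.append(f"  subgraph cluster_{i} {{")
        lines.append(f'    label="group {i}";')
        for layer in group:
            lines.append(f'    "{layer}";')
        lines.append("  }")
    for src, dst in edges:
        lines.append(f'  "{src}" -> "{dst}";')
    lines.append("}")
    return "\n".join(lines)


# Toy residual block: stem -> conv_a -> conv_b, then add(stem, conv_b) -> fc.
layers = ["stem", "conv_a", "conv_b"]  # layers with prunable output channels
edges = [("stem", "conv_a"), ("conv_a", "conv_b"),
         ("stem", "add"), ("conv_b", "add"), ("add", "fc")]
coupled = [("stem", "conv_b")]  # the residual add forces matching channels
groups = pruning_groups(layers, coupled)
print(groups)  # [['stem', 'conv_b'], ['conv_a']]
print(to_dot(edges, groups))
```

Saving the DOT text to a file and running `dot -Tpdf graph.dot -o graph.pdf` renders it with Graphviz.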

To do list

Contributions Welcome

We would greatly appreciate contributions in any form from our open-source community, such as bug fixes, new features, and new tutorials.

We are humbled to contribute to the AI community, and we look forward to working together with the community to make DNN training and compression more automatic and convenient.

Open for collaboration.

We are open to and happy about collaborations. Feel free to reach out to tiachen@microsoft.com if you have any interesting ideas.

Legacy OTOv2 repository

The previous OTOv2 repo has been moved into legacy_branch for academic replication.

Citation

If you find this repository useful, please consider starring it and citing our papers:

For OTOv3 preprint
@article{chen2023otov3,
  title={OTOv3: Automatic Architecture-Agnostic Neural Network Training and Compression from Structured Pruning to Erasing Operators},
  author={Chen, Tianyi and Ding, Tianyu and Zhu, Zhihui and Chen, Zeyu and Wu, HsiangTao and Zharkov, Ilya and Liang, Luming},
  journal={arXiv preprint arXiv:2312.09411},
  year={2023}
}

For LoRAShear preprint
@article{chen2023lorashear,
  title={LoRAShear: Efficient Large Language Model Structured Pruning and Knowledge Recovery},
  author={Chen, Tianyi and Ding, Tianyu and Yadav, Badal and Zharkov, Ilya and Liang, Luming},
  journal={arXiv preprint arXiv:2310.18356},
  year={2023}
}

For AdaHSPG+ publication in TMLR (theoretical optimization paper)
@article{dai2023adahspg,
  title={An adaptive half-space projection method for stochastic optimization problems with group sparse regularization},
  author={Dai, Yutong and Chen, Tianyi and Wang, Guanyi and Robinson, Daniel P},
  journal={Transactions on Machine Learning Research},
  year={2023}
}

For OTOv2 publication in ICLR 2023
@inproceedings{chen2023otov2,
  title={OTOv2: Automatic, Generic, User-Friendly},
  author={Chen, Tianyi and Liang, Luming and Ding, Tianyu and Zhu, Zhihui and Zharkov, Ilya},
  booktitle={International Conference on Learning Representations},
  year={2023}
}

For OTOv1 publication in NeurIPS 2021
@inproceedings{chen2021otov1,
  title={Only Train Once: A One-Shot Neural Network Training And Pruning Framework},
  author={Chen, Tianyi and Ji, Bo and Ding, Tianyu and Fang, Biyi and Wang, Guanyi and Zhu, Zhihui and Liang, Luming and Shi, Yixin and Yi, Sheng and Tu, Xiao},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}