Open CanYing0913 opened 7 months ago
Hey :-)
Thank you for reaching out.
For example, the Resnet9 model can be found here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/models/resnet9.py#L43
So, for the default pre-training, you could run the standard training script:
python training/train.py --device cuda:0 --opt adam --model resnet9 --cifar10 --lr 0.001 --lr-scheduler cosineannealinglr --epochs 200 --amp --output-dir /path/to/model-checkpoint-output-base.pth
You can adapt the hyperparameters around here (for retraining + fine-tuning):
Then, for the layer-per-layer retraining run:
python retraining.py 0 -single -testname resnet9-lpl-0.001 -checkpoint /path/to/model-checkpoint-output-base.pth
This will generate checkpoints for each layer:
Then, to run fine-tuning, update EPOCHS to 300 here:
and run with the checkpoint of the last layer:
python retraining.py 0 -single -testname resnet9-lpl-int8-lut-fine-tune -checkpoint /path/to/output-retraining/resnet9-lpl-0.001/checkpoints/retrained_checkpoint_7_trained.pth
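To summarize the three stages above in one place, here is a minimal sketch that only assembles the command lines. Every flag and file name is copied from the commands in this thread; the helper names (`pretrain_cmd`, `retrain_cmd`, `last_layer_checkpoint`) and the checkpoint naming pattern are assumptions made for illustration, inferred from the single path quoted above, and are not part of the repository.

```python
import shlex

# Flags and paths are copied from the commands quoted in this thread.
# The helper functions themselves are hypothetical and exist only here.
BASE_CKPT = "/path/to/model-checkpoint-output-base.pth"
RETRAIN_DIR = "/path/to/output-retraining"


def pretrain_cmd(output=BASE_CKPT):
    """Stage 1: standard pre-training of the ResNet9 baseline."""
    return [
        "python", "training/train.py", "--device", "cuda:0", "--opt", "adam",
        "--model", "resnet9", "--cifar10", "--lr", "0.001",
        "--lr-scheduler", "cosineannealinglr", "--epochs", "200", "--amp",
        "--output-dir", output,
    ]


def retrain_cmd(testname, checkpoint):
    """Stages 2 and 3 both call retraining.py; only the inputs differ."""
    return ["python", "retraining.py", "0", "-single",
            "-testname", testname, "-checkpoint", checkpoint]


def last_layer_checkpoint(testname, layer_idx=7):
    """Assumed naming pattern, inferred from the one path quoted above."""
    return (f"{RETRAIN_DIR}/{testname}/checkpoints/"
            f"retrained_checkpoint_{layer_idx}_trained.pth")


stages = [
    pretrain_cmd(),                                    # 1. pre-training
    retrain_cmd("resnet9-lpl-0.001", BASE_CKPT),       # 2. layer-per-layer
    retrain_cmd("resnet9-lpl-int8-lut-fine-tune",      # 3. fine-tuning
                last_layer_checkpoint("resnet9-lpl-0.001")),
]

for argv in stages:
    # Print only; pass argv to subprocess.run(...) to actually execute it.
    print(shlex.join(argv))
```

Note that the sketch does not change EPOCHS; that still has to be edited by hand before the third stage, as described above.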
I hope this helps :-) I know it could be cleaned up.
Thank you for your thorough answer! I took a look at HalutConv2d here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L456
It seems you replaced every Conv2d and Linear with its halut version while keeping good accuracy. I can't wait to see accuracies with larger models using your conv2d and linear :-)
In your halut_matmul_forward function (https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L141), the prototype is defined as self.P. However, I cannot find any place that assigns the prototypes, except https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L308-L314, but that simply loads tensors from a dictionary. How (and where) are the prototype and lut trained?
I see how this can be confusing. The prototype is used when we run it as in the LUT-NN paper (https://arxiv.org/abs/2302.03213). Back in the day, there was no public implementation of that :-) Now there is. I would highly recommend also checking out their implementation here: https://github.com/lutnn/blink-mm
So P is basically set to null when running halut_matmul. The initial LUT is trained according to the MADDNESS paper: https://arxiv.org/abs/2106.10860 and the code here: https://github.com/dblalock/bolt
I heavily refactored their code and it is here: https://github.com/joennlae/halutmatmul/tree/master/src/python/halutmatmul
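For readers unfamiliar with what "prototypes" and a "lut" are in this setting, the following is a minimal product-quantization sketch in plain Python. It is not the code from this repository or from bolt: it swaps in ordinary k-means with nearest-prototype encoding (closer to classic PQ), whereas MADDNESS learns hash-tree encoders. All function names are hypothetical.

```python
import random


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def split(vec, C):
    """Split a vector into C equal-length subvectors, one per codebook."""
    d = len(vec) // C
    return [vec[c * d:(c + 1) * d] for c in range(C)]


def learn_prototypes(rows, C, K, iters=10, seed=0):
    """Plain k-means per codebook (assumption: stands in for the real learner)."""
    rng = random.Random(seed)
    protos = []
    for c in range(C):
        subs = [split(r, C)[c] for r in rows]
        cents = [list(s) for s in rng.sample(subs, K)]
        for _ in range(iters):
            buckets = [[] for _ in range(K)]
            for s in subs:
                k = min(range(K), key=lambda j: sq_dist(s, cents[j]))
                buckets[k].append(s)
            for k, b in enumerate(buckets):
                if b:  # keep the old centroid if a bucket is empty
                    cents[k] = [sum(col) / len(b) for col in zip(*b)]
        protos.append(cents)
    return protos


def build_lut(protos, w):
    """lut[c][k] = dot(prototype k of codebook c, matching slice of w)."""
    w_subs = split(w, len(protos))
    return [[sum(p * q for p, q in zip(proto, w_subs[c])) for proto in cents]
            for c, cents in enumerate(protos)]


def encode(x, protos):
    """Index of the nearest prototype in each codebook."""
    subs = split(x, len(protos))
    return [min(range(len(cents)), key=lambda j: sq_dist(subs[c], cents[j]))
            for c, cents in enumerate(protos)]


def approx_dot(x, protos, lut):
    """Approximate dot(x, w) with C table lookups and additions."""
    return sum(lut[c][k] for c, k in enumerate(encode(x, protos)))


rng = random.Random(1)
rows = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(64)]
w = [rng.gauss(0, 1) for _ in range(8)]
protos = learn_prototypes(rows, C=4, K=4)
lut = build_lut(protos, w)
exact = sum(a * b for a, b in zip(rows[0], w))
print("exact:", exact, "approx:", approx_dot(rows[0], protos, lut))
```

The point of the sketch is the division of labor: the prototypes depend only on the activations, the lut depends on the prototypes and the weights, and inference replaces the multiply-accumulate with one lookup per codebook.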
The learning algorithm is defined here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/halutmatmul.py#L298
and here:
and it is ultimately called during retraining from here:
Thanks again! I noted that the implementations for plain data vectors and convolution models are separate (i.e. halutmatmul and HalutConv2d/HalutLinear have totally different training procedures). Where are the prototypes and lut trained for HalutConv2d? It seems the update_lut function only performs a rounding operation on lut itself, but I could not find anywhere that self.P and self.lut are trained.
I am also quite confused: is HalutConv2d trained along with the original model training? If so, where are the functions that learn the prototypes and construct the lut at the end of the iterations? My current approach is to train my PQConv2d offline (i.e. after the entire model is trained, so I basically need to iterate through the dataset batches again), but the two should be similar.
I'm not entirely certain where the misunderstanding lies, but I'll do my best to clarify :-)
The lut and P are initialized using the methods described in the previous response. During training, updates are propagated back in FP16, but for the forward pass, we utilize the update_lut function to quantize the lut and then process the activation through it. This approach is fairly typical for quantization-aware training.
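The split described above (quantized lut in the forward pass, higher-precision updates in the backward pass) can be illustrated with a small straight-through sketch. This is an assumption about the general pattern, not a copy of update_lut: the functions below are hypothetical, Python floats are 64-bit rather than FP16, and the gradient is written out by hand instead of using autograd. In PyTorch the same idea is often written as `w + (quantize(w) - w).detach()`.

```python
def quantize_int8(row, scale):
    """Round each entry to the nearest int8 step, then map back to float."""
    return [max(-128, min(127, round(v / scale))) * scale for v in row]


def forward(codes, lut_q):
    """Output = sum of the looked-up quantized entries, one per codebook."""
    return sum(lut_q[c][k] for c, k in enumerate(codes))


def train_step(lut, codes, target, scale, lr):
    """One SGD step on 0.5 * (output - target)^2 with a straight-through update."""
    lut_q = [quantize_int8(row, scale) for row in lut]  # forward uses rounded lut
    err = forward(codes, lut_q) - target
    # Straight-through: rounding has zero gradient almost everywhere, so we
    # treat d(lut_q)/d(lut) as 1 and update the float master copy directly.
    for c, k in enumerate(codes):
        lut[c][k] -= lr * err
    return 0.5 * err ** 2


# Float master copy: 2 codebooks with 2 entries each (toy sizes).
lut = [[0.0, 0.0], [0.0, 0.0]]
samples = [([0, 1], 1.5), ([1, 0], -0.5)]  # (codes, target output)
scale, lr = 0.05, 0.1

losses = []
for epoch in range(200):
    losses.append(sum(train_step(lut, codes, t, scale, lr) for codes, t in samples))

print("first epoch loss:", losses[0], "last epoch loss:", losses[-1])
print("float lut:", lut)
print("quantized lut:", [quantize_int8(row, scale) for row in lut])
```

In this toy setup the float lut keeps moving even when a single update is smaller than one quantization step, which is why a rounding-only function in the forward path is compatible with the lut still being trained.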
Hi, thanks for your work! I want to know the standard way to train halutmatmul at the model level, as I am trying various model architectures with different datasets to examine the accuracy loss from applying PQ in inference. That is, train it with the same training dataset used to train the model, then run inference to measure accuracy. Currently, example.py only shows training and inference at the matrix-multiplication level. If you can point me to the right file(s) to look at, that would be great. Thanks!