joennlae / halutmatmul

Hashed Lookup Table based Matrix Multiplication (halutmatmul) - Stella Nera accelerator
MIT License

Train and test on model level #17

Open CanYing0913 opened 7 months ago

CanYing0913 commented 7 months ago

Hi, thanks for your work! I want to know the standard way to train halutmatmul at the model level, as I am trying various model architectures with different datasets to examine the accuracy loss from applying PQ in inference. That is, I want to train it with the same training dataset used to train the model and then run inference to collect accuracy. Currently, example.py only shows training and inference at the matrix-multiplication level. If you can point me to the right file(s) to look at, that would be great, thanks!

joennlae commented 7 months ago

Hey :-)

Thank you for reaching out.

For example, the Resnet9 model can be found here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/models/resnet9.py#L43

So, for the default pre-training, you could run the standard training script:

python training/train.py --device cuda:0 --opt adam --model resnet9 --cifar10 --lr 0.001 --lr-scheduler cosineannealinglr --epochs 200 --amp --output-dir /path/to/model-checkpoint-output-base.pth
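For orientation, those flags map onto standard PyTorch training components. The following is only a rough sketch of what they correspond to (Adam, cosine-annealing LR schedule, AMP), not the repo's actual training/train.py; `model` and `train_loader` are assumed to already exist:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Rough sketch only: Adam optimizer, cosine-annealing LR, automatic mixed precision.
# `model` (e.g. the ResNet9 above) and `train_loader` (CIFAR-10) are assumed to exist.
device = torch.device("cuda:0")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
scaler = GradScaler()
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(200):
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        with autocast():                          # AMP forward pass
            loss = criterion(model(images), targets)
        scaler.scale(loss).backward()             # scaled backward for mixed precision
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()                              # one cosine-annealing step per epoch

torch.save(model.state_dict(), "/path/to/model-checkpoint-output-base.pth")
```

The real script additionally handles checkpointing, logging, and so on, so prefer running training/train.py as shown above.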

You can adapt the hyperparameters around here (for retraining + fine-tuning):

https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/retraining.py#L380

Then, for the layer-per-layer retraining, run:

python retraining.py 0 -single -testname resnet9-lpl-0.001 -checkpoint /path/to/model-checkpoint-output-base.pth

This will generate checkpoints for each layer:

https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/retraining.py#L502

Then, to run fine-tuning, update EPOCHS to 300 here:

https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/retraining.py#L378

and run with the checkpoint of the last layer:

python retraining.py 0 -single -testname resnet9-lpl-int8-lut-fine-tune -checkpoint /path/to/output-retraining/resnet9-lpl-0.001/checkpoints/retrained_checkpoint_7_trained.pth

I hope this helps :-) I know it could be cleaned up.

CanYing0913 commented 7 months ago

Thank you for your thorough answer! I took a look at HalutConv2d here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L456 It seems you replaced every Conv2d and Linear with the halut version with good accuracy. I can't wait to see accuracies with larger models using your conv2d and linear :-)

In your halut_matmul_forward function: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L141 the prototype is defined as self.P; however, I cannot find any place that assigns the prototypes, except: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/modules.py#L308-L314 but that simply loads tensors from a dictionary. How (and where) are the prototype and lut trained?

joennlae commented 7 months ago

I see how this can be confusing. The prototype is used when we run it like in the LUT-NN paper (https://arxiv.org/abs/2302.03213). At the time, there was no public implementation of that :-) Now there is. I would highly recommend also checking out their implementation here: https://github.com/lutnn/blink-mm

So P is basically set to null when running halut_matmul. The initial LUT is trained according to the MADDNESS paper: https://arxiv.org/abs/2106.10860 and its reference implementation here: https://github.com/dblalock/bolt
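For intuition, here is a small, self-contained NumPy sketch of the product-quantization-style lookup that MADDNESS builds on: offline, the input dimension is split into C subspaces, K prototypes are learned per subspace (plain k-means here for simplicity; the actual MADDNESS encoder uses learned hash trees), and a LUT of prototype-times-weight partial products is precomputed; online, each input is encoded to prototype indices and the matmul collapses into LUT gathers and sums. Function names like learn_lut and halut_like_matmul are purely illustrative and not the repo's API:

```python
import numpy as np
from scipy.cluster.vq import kmeans2

def learn_lut(A_train, W, C=4, K=16):
    """Offline step: learn K prototypes per subspace and precompute the LUT.

    A_train: (N, D) sampled input activations, W: (D, M) layer weights.
    Assumes D is divisible by C. Returns prototypes (C, K, D//C) and lut (C, K, M).
    """
    D = A_train.shape[1]
    d = D // C
    prototypes, lut = [], []
    for c in range(C):
        sub = A_train[:, c * d:(c + 1) * d]
        cents, _ = kmeans2(sub, K, minit="points")          # (K, d) prototypes
        prototypes.append(cents)
        lut.append(cents @ W[c * d:(c + 1) * d, :])          # (K, M) partial products
    return np.stack(prototypes), np.stack(lut)

def halut_like_matmul(A, prototypes, lut):
    """Online step: encode each subvector to its nearest prototype, then sum LUT rows."""
    C, K, d = prototypes.shape
    out = np.zeros((A.shape[0], lut.shape[2]))
    for c in range(C):
        sub = A[:, c * d:(c + 1) * d]                        # (N, d)
        dists = ((sub[:, None, :] - prototypes[c][None]) ** 2).sum(-1)
        codes = dists.argmin(1)                              # (N,) prototype indices
        out += lut[c][codes]                                 # gather + accumulate
    return out

# Rough usage: the approximation error shrinks as C and K grow.
A_train = np.random.randn(2048, 64)
W = np.random.randn(64, 32)
P, L = learn_lut(A_train, W)
A = np.random.randn(8, 64)
print(np.abs(halut_like_matmul(A, P, L) - A @ W).mean())
```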

I heavily refactored their code and it is here: https://github.com/joennlae/halutmatmul/tree/master/src/python/halutmatmul

The learning algorithm is defined here: https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/halutmatmul.py#L298

and here:

https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/learn.py#L31-L46

and it is ultimately called during retraining from here:

https://github.com/joennlae/halutmatmul/blob/4655152246e4a7203a9401bd6f4a71905e9cb682/src/python/halutmatmul/model.py#L245

CanYing0913 commented 7 months ago

Thanks again! I noted that the implementations for plain data vectors and for the convolution/linear modules are separate (i.e. halutmatmul and HalutConv2d/HalutLinear have totally different training procedures). Where are the prototypes and lut trained for HalutConv2d? It seems the update_lut function only performs a rounding operation on the lut itself, and I could not find anywhere that self.P and self.lut are trained. I am also quite confused: is HalutConv2d trained along with the original model training? If so, where are the functions that learn the prototypes and construct the lut at the end of the iterations? My current approach is to train my PQConv2d offline (i.e. after the entire model is trained, which means I need to iterate through the dataset batches again), but they should be similar.

joennlae commented 7 months ago

I'm not entirely certain where the misunderstanding lies, but I'll do my best to clarify :-)

The lut and P are initialized using the methods described in the previous response. During training, updates are propagated backward in FP16, but in the forward pass we use the update_lut function to quantize the lut and then push the activation through it. This approach is fairly typical for quantization-aware training.
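In case it helps, here is a minimal PyTorch sketch of that pattern (a straight-through estimator around LUT quantization: the forward pass sees int8-like lut values, while the backward pass updates the higher-precision master copy). It illustrates the general technique only and is not the actual update_lut code; the names RoundSTE and forward_with_quantized_lut are made up for the example:

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Quantize in the forward pass, let gradients flow straight through."""

    @staticmethod
    def forward(ctx, lut, scale):
        # Simulate int8 LUT entries while keeping the FP16/FP32 master copy trainable.
        return torch.clamp(torch.round(lut / scale), -128, 127) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None  # straight-through: gradient goes to the unquantized lut


def forward_with_quantized_lut(lut_master, scale, codes):
    """Gather output contributions from the quantized LUT using the encoded activations."""
    lut_q = RoundSTE.apply(lut_master, scale)               # (C, K, M) quantized view
    # codes: (N, C) prototype indices per codebook; sum the selected LUT rows.
    gathered = lut_q[torch.arange(lut_q.shape[0]), codes]   # (N, C, M)
    return gathered.sum(dim=1)                              # (N, M)
```

The key point is that rounding itself has zero gradient almost everywhere, so the straight-through backward lets the optimizer keep adjusting the FP16 master lut while the forward pass always runs on quantized values.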