kmheckel / Q-S5

A Fully Quantized SSM Implementation
https://arxiv.org/abs/2406.09477
MIT License

Quantize non-SSM components of the architecture #6

Closed stevenabreu7 closed 4 months ago

stevenabreu7 commented 5 months ago

In the S5 fork, we're currently only quantizing the SSM. However, there are many MLPs and batch/layer norms that should also be quantized. Perhaps there could be a general flag like `non_ssm_bits: int?` that specifies the precision to use for these.

If we want to claim efficient inference on HW, we should have integer precision everywhere!
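For illustration, a minimal sketch of what such a flag could look like, assuming a simple symmetric fake-quantization scheme; the names (`fake_quant`, `QuantDense`, `GLUBlock`) are hypothetical and not the repo's actual API:

```python
from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


def fake_quant(x, bits):
    """Symmetric per-tensor fake quantization to signed `bits`-bit integers."""
    max_int = 2 ** (bits - 1) - 1
    scale = jnp.max(jnp.abs(x)) / max_int + 1e-8
    return jnp.round(x / scale) * scale


class QuantDense(nn.Module):
    """Dense layer whose input and kernel are fake-quantized when non_ssm_bits is set."""
    features: int
    non_ssm_bits: Optional[int] = None  # None -> keep full fp32 precision

    @nn.compact
    def __call__(self, x):
        kernel = self.param("kernel", nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features))
        if self.non_ssm_bits is not None:
            x = fake_quant(x, self.non_ssm_bits)
            kernel = fake_quant(kernel, self.non_ssm_bits)
        return x @ kernel


class GLUBlock(nn.Module):
    """Gated block in the spirit of qlayers.py: x * sigmoid(out2(x))."""
    features: int
    non_ssm_bits: Optional[int] = None

    @nn.compact
    def __call__(self, x):
        x = QuantDense(self.features, self.non_ssm_bits)(x)
        return x * jax.nn.sigmoid(QuantDense(self.features, self.non_ssm_bits)(x))
```

The same flag could then be threaded through the encoder/decoder and norm layers so that the whole non-SSM path runs at a declared integer precision.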

kmheckel commented 5 months ago

The question is: would it be better to pass the QuantizedOperations object around so all of the other layers can extract the dot_general they need, or to pass the config and build the dot_general locally? I think the first option is probably better, but I already have code written for the latter. Passing the config might also reduce bloat if we need to create more case-specific quantized operations for things such as LayerNorm. I just pushed what I've implemented so far; it hasn't been tested yet.
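Roughly, the two options look like this (a sketch only; the internals of `QuantizedOperations` are assumed, and `jax.lax.dot_general` stands in for the actual quantized dot):

```python
from typing import Callable, Optional

import jax
import flax.linen as nn

# Stand-in for a quantized dot_general built from the q-config.
quant_dot_stand_in: Callable = jax.lax.dot_general


# Option A: build the QuantizedOperations object once and pass it around;
# each layer pulls out the dot_general it needs.
class EncoderA(nn.Module):
    features: int
    quant_dot: Callable = quant_dot_stand_in  # e.g. q_ops.quant_dot

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, dot_general=self.quant_dot)(x)


# Option B: pass only the config (here just a bit width) and build the
# dot_general locally inside each layer.
class EncoderB(nn.Module):
    features: int
    non_ssm_bits: Optional[int] = None

    @nn.compact
    def __call__(self, x):
        # A real implementation would construct the quantized dot_general
        # from the bit width here; None keeps the full-precision default.
        dot = quant_dot_stand_in if self.non_ssm_bits is not None else jax.lax.dot_general
        return nn.Dense(self.features, dot_general=dot)(x)
```

Either way, whatever gets passed must be call-compatible with `jax.lax.dot_general`, since Flax's `Dense` forwards keyword arguments such as `precision` to it (which is exactly what the traceback below runs into).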

kmheckel commented 5 months ago

Trying to run, getting this error. Will investigate.


```
wandb: Tracking run with wandb version 0.16.6
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[*] Setting Randomness...
[*] Generating MNIST Classification Dataset
[*] Starting S5 Training on `mnist-classification` =>> Initializing...
Lambda.shape=(128,)
V.shape=(256, 128)
Vinv.shape=(128, 256)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/legion/Code/ngsm/S5/run_qtrain.py", line 223, in <module>
    train(parser.parse_args())
  File "/home/legion/Code/ngsm/S5/s5/qtrain.py", line 180, in train
    state = create_train_state(
  File "/home/legion/Code/ngsm/S5/s5/train_helpers.py", line 131, in create_train_state
    variables = model.init({"params": init_rng,
  File "/home/legion/Code/ngsm/S5/s5/qseq_model.py", line 166, in __call__
    x = self.encoder(x, integration_timesteps)
  File "/home/legion/Code/ngsm/S5/s5/qseq_model.py", line 71, in __call__
    x = layer(x)
  File "/home/legion/Code/ngsm/S5/s5/qlayers.py", line 79, in __call__
    x = x * jax.nn.sigmoid(self.out2(x))
  File "/home/legion/.local/lib/python3.10/site-packages/flax/linen/linear.py", line 274, in __call__
    y = dot_general(
TypeError: quant_dot_for_dot.<locals>._dot() got an unexpected keyword argument 'precision'
```
kmheckel commented 4 months ago

I think I've figured out the issue: I was using `q_dot_maybe(non_ssm_precision)`, which returns the specialized JIT vector dot product, which is why the error about the unexpected keyword 'precision' is being thrown. I'll fix this by setting the `dot_general` directly for the MLPs and the extra dense layers, with the default precision set to fp32. I'll make and test these adjustments and push today.
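To illustrate the mismatch (a sketch; the `_dot` body is made up, only its restricted signature matters): Flax's `nn.Dense` forwards keyword arguments such as `precision` to whatever `dot_general` it is given, so a specialized two-operand dot fails, while a wrapper that mirrors `jax.lax.dot_general`'s signature goes through.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


# Shape of the failing case: a specialized vector-dot closure that only takes
# the two operands, so Dense's `precision=...` keyword triggers the TypeError above.
def _dot(lhs, rhs):
    return jnp.dot(lhs, rhs)


# A drop-in dot_general has to mirror jax.lax.dot_general's signature so the
# keywords Flax forwards (at least `precision` here) are accepted.
def quant_dot_general(lhs, rhs, dimension_numbers, precision=None,
                      preferred_element_type=None):
    # (fake/integer quantization of lhs and rhs would go here)
    return jax.lax.dot_general(lhs, rhs, dimension_numbers,
                               precision=precision,
                               preferred_element_type=preferred_element_type)


# nn.Dense(features=64, dot_general=_dot) fails at init with the error above;
# nn.Dense(features=64, dot_general=quant_dot_general) initializes fine.
```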

kmheckel commented 4 months ago

Fixed the issues and will push them. Instead of passing the quantization configs (with other unnecessary information) around purely to carry the non_ssm_precision, I changed it to communicate non_ssm_precision as a plain int; this avoids issues with non-hashable arguments. We will still have to go through the architectures at some point to probe the quantization, but this gets us a lot closer.
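For context on the hashability point, a generic illustration (not the repo's code): anything used as a static, trace-time argument has to be hashable, which a plain int is and a dict-style config is not.

```python
import jax
import jax.numpy as jnp


def scale(x, bits):
    return x / (2 ** (bits - 1) - 1)


# A plain int is hashable, so it can be passed as a static argument without trouble.
f = jax.jit(scale, static_argnums=1)
print(f(jnp.ones(4), 8))

# A dict-style config in the same position fails, because static arguments
# must be hashable:
# g = jax.jit(lambda x, cfg: scale(x, cfg["bits"]), static_argnums=1)
# g(jnp.ones(4), {"bits": 8})  # error: non-hashable static argument
```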

kmheckel commented 4 months ago

This commit takes care of most of the remaining work except for quantizing the norm operations: https://github.com/lindermanlab/S5/commit/31a099c2a8fe7b818e5aa324b8e50ad7f7bf8b3e

kmheckel commented 4 months ago

For LayerNorm, we should set use_bias=False to be cautious about injecting additional information.
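In Flax this is just a constructor flag; a minimal example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# LayerNorm with a learned scale but no additive bias term.
norm = nn.LayerNorm(use_bias=False)
params = norm.init(jax.random.PRNGKey(0), jnp.ones((2, 16)))
# Only a `scale` parameter is created; there is no `bias` to carry extra information.
print(jax.tree_util.tree_map(jnp.shape, params))
```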

kmheckel commented 4 months ago

I don't think we should worry about masked_meanpool since it lies at the end of the entire model pipeline, so it wouldn't break quantization in the middle of the model.
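For reference, a masked mean pool along these lines (a sketch assuming a time-major `(T, H)` input and an integer valid length; the repo's exact signature may differ) only produces the final pooled features for the classification head, so leaving it in floating point does not interrupt the quantized path:

```python
import jax.numpy as jnp


def masked_meanpool(x, length):
    """Mean over the first `length` timesteps of a (T, H) sequence."""
    T = x.shape[0]
    mask = (jnp.arange(T) < length)[:, None]   # (T, 1) validity mask for padding
    return jnp.sum(mask * x, axis=0) / length  # (H,) pooled features


# Example: a padded sequence of 6 steps where only the first 4 are valid.
x = jnp.ones((6, 8))
print(masked_meanpool(x, 4))  # == ones(8); the padded steps are ignored
```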

kmheckel commented 4 months ago

Closing since all major wickets have been hit for the preprint.