- `non_ssm_precision`
- `dot_general` to `QuantizationConfig`
- `masked_meanpool`
https://github.com/stevenabreu7/S5/blob/c4a22d830568ada26b30ff2be643d8b69ca04002/s5/seq_model.py#L70

Question is, would it be better to pass the `QuantizedOperations` object around for all of the other layers to extract the necessary `dot_general` operation, or would it be better to pass the config and then build the `dot_general` locally? I think the first one is maybe better, but I have code written to do the latter. Passing the config might be nice since it could reduce bloat if we need to create more case-specific quantized operations for things such as LayerNorm. Just pushed what I've implemented so far; haven't tested it yet.
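To make the two options concrete, here is a minimal sketch. The `QuantizedOperations` attribute layout and the `build_quant_dot_general` factory are placeholders for illustration, not the repo's actual API; passing a custom `dot_general` into `nn.Dense` is what the traceback below suggests is already being done.

```python
# Illustrative sketch only: two ways to give the non-SSM Dense layers a
# quantized dot_general. Names here are placeholders, not the repo's API.
from typing import Any
import jax
import flax.linen as nn


def build_quant_dot_general(config):
    # Stand-in factory: the real version would derive a quantized dot op
    # from the config; here it simply falls back to lax.dot_general.
    del config
    return jax.lax.dot_general


class GatedUnit(nn.Module):
    d_model: int
    q_ops: Any = None     # option 1: prebuilt QuantizedOperations passed in
    q_config: Any = None  # option 2: only the config is passed

    @nn.compact
    def __call__(self, x):
        if self.q_ops is not None:
            dg = self.q_ops.dot_general                   # option 1: just extract it
        else:
            dg = build_quant_dot_general(self.q_config)   # option 2: build locally
        out2 = nn.Dense(self.d_model, dot_general=dg)
        return x * jax.nn.sigmoid(out2(x))
```

Option 1 centralizes the construction but means the ops object travels through every module; option 2 keeps the constructor arguments small at the cost of every layer knowing how to build its own op.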
Trying to run, getting this error. Will investigate.

```
wandb: Tracking run with wandb version 0.16.6
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
[*] Setting Randomness...
[*] Generating MNIST Classification Dataset
[*] Starting S5 Training on `mnist-classification` =>> Initializing...
Lambda.shape=(128,)
V.shape=(256, 128)
Vinv.shape=(128, 256)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/legion/Code/ngsm/S5/run_qtrain.py", line 223, in <module>
train(parser.parse_args())
File "/home/legion/Code/ngsm/S5/s5/qtrain.py", line 180, in train
state = create_train_state(
File "/home/legion/Code/ngsm/S5/s5/train_helpers.py", line 131, in create_train_state
variables = model.init({"params": init_rng,
File "/home/legion/Code/ngsm/S5/s5/qseq_model.py", line 166, in __call__
x = self.encoder(x, integration_timesteps)
File "/home/legion/Code/ngsm/S5/s5/qseq_model.py", line 71, in __call__
x = layer(x)
File "/home/legion/Code/ngsm/S5/s5/qlayers.py", line 79, in __call__
x = x * jax.nn.sigmoid(self.out2(x))
File "/home/legion/.local/lib/python3.10/site-packages/flax/linen/linear.py", line 274, in __call__
y = dot_general(
TypeError: quant_dot_for_dot.<locals>._dot() got an unexpected keyword argument 'precision'
```
I think I've figured out the issue: I was using `q_dot_maybe(non_ssm_precision)`, which returns the specialized JIT vector dot product operation, which is why the error about the unexpected keyword `precision` is thrown. I'll fix this by setting the `dot_general` directly for the MLPs and the extraneous dense layers, and set the default precision to fp32. Will make/test these adjustments and then push today.
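For reference, Flax's `Dense` calls the provided `dot_general` with a `precision` keyword (visible at `linear.py:274` in the traceback), so whatever we install there has to accept the same arguments as `jax.lax.dot_general`. A minimal sketch of a compatible wrapper, with a toy fake-quantizer standing in for the real quantized op:

```python
# Sketch only: a dot_general replacement whose signature matches what
# flax.linen.Dense passes through (positional args plus `precision`).
import jax.numpy as jnp
from jax import lax


def fake_quant(x, bits):
    # Toy symmetric fake-quantization, purely illustrative.
    max_q = 2.0 ** (bits - 1) - 1.0
    scale = jnp.max(jnp.abs(x)) / max_q + 1e-9
    return jnp.round(x / scale) * scale


def make_quant_dot_general(bits):
    def _dot_general(lhs, rhs, dimension_numbers,
                     precision=None, preferred_element_type=None):
        # Quantize both operands, then defer to lax.dot_general so the
        # `precision` kwarg Flax passes is forwarded rather than rejected.
        return lax.dot_general(
            fake_quant(lhs, bits), fake_quant(rhs, bits),
            dimension_numbers,
            precision=precision,
            preferred_element_type=preferred_element_type,
        )
    return _dot_general
```

With a signature like this, `nn.Dense(features, dot_general=make_quant_dot_general(bits))` should no longer trip over the unexpected `precision` keyword.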
Fixed the issues and will push them. Instead of passing the quantization configs, with their other unnecessary information, around purely to carry the `non_ssm_precision`, I changed it to communicate just the `non_ssm_precision` as an int. This change was due to issues with non-hashable arguments. We will still have to go through the architectures at some point to probe the quantization, but this gets us a lot closer.
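A tiny illustration of the hashability point (the config class here is a stand-in, not the repo's actual one): anything that has to act as a static argument, e.g. for `jax.jit` or `functools.partial`-style caching, must be hashable, and a plain int is, while a mutable config object generally is not.

```python
import dataclasses


@dataclasses.dataclass
class QuantizationConfig:  # stand-in for illustration only
    a_bits: int = 8
    w_bits: int = 8


cfg = QuantizationConfig()
try:
    hash(cfg)  # non-frozen dataclasses set __hash__ = None
except TypeError as err:
    print("config is not hashable:", err)

print("int is hashable:", hash(8))  # a bare precision int travels fine
```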
This commit addresses most of the remaining issues except for quantizing the norm operations: https://github.com/lindermanlab/S5/commit/31a099c2a8fe7b818e5aa324b8e50ad7f7bf8b3e
For LayerNorm, we should set `use_bias=False` to be cautious about injecting additional information.
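In Flax that is a one-line change (shown here just for clarity):

```python
import flax.linen as nn

# Scale-only LayerNorm: no learned additive bias to inject extra information.
norm = nn.LayerNorm(use_bias=False, use_scale=True)
```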
I don't think we should worry about `masked_meanpool`, since it lies at the very end of the model pipeline, so it wouldn't break quantization in the middle of the model.
Closing since all major wickets have been hit for the preprint.
In the S5 fork, we're currently only quantizing the SSM. However, there are many MLPs and batch/layer norms that should also be quantized. Perhaps there could be a general flag like `non_ssm_bits: int` that specifies the precision to use for these. If we want to claim efficient inference on hardware, we should have integer precision everywhere!
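A rough sketch of what such a flag could look like on the training script (the argument name and default are just the suggestion above, not something that exists yet):

```python
import argparse

parser = argparse.ArgumentParser()
# Proposed catch-all precision for MLPs, norms, encoder/decoder, etc.;
# None would keep all non-SSM ops in full precision.
parser.add_argument("--non_ssm_bits", type=int, default=None,
                    help="Bit width for all non-SSM operations.")
```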