[Open] HoangHoang1408 opened this issue 4 days ago
# --- Calibration capture ----------------------------------------------------
# Run a few diffusion steps and intercept the inputs to the first transformer
# block; the recorded tensors become the calibration set for GPTQ.

transformer_model = to_quantize_modules['transformer']
layers = transformer_model.transformer_blocks

inputs = {}       # sample index -> {'input': hidden_states, **extra forward kwargs}
outputs = {}      # sample index -> output of the layer currently being processed
index = {"i": 0}  # mutable counter shared with Catcher.forward


class Catcher(nn.Module):
    """Wrap the first transformer block, record every call's arguments,
    then abort the forward pass so the rest of the pipeline never runs."""

    def __init__(self, module, layer_index):
        # NOTE(fix): the pasted snippet had `init` — markdown stripped the
        # dunder underscores; without __init__ this class cannot be built.
        super().__init__()
        self.module = module
        self.layer_index = layer_index

    def forward(self, inp, **kwargs):
        inputs[index['i']] = {"input": inp}
        for key, val in kwargs.items():
            inputs[index['i']][key] = val
        index['i'] += 1
        # Deliberate: abort execution once the inputs have been captured.
        raise ValueError


layers[0] = Catcher(layers[0], 0)
for d in data[:args['nsamples']]:
    try:
        model_pipeline(d, num_inference_steps=2)
    except ValueError:
        # Expected: raised by Catcher after recording this sample's inputs.
        # Any other exception is a real error and should propagate.
        pass
layers[0] = layers[0].module
# Pick the quantizer factory: vector quantization (VQ) when requested,
# otherwise the plain scalar Quantizer.
if args['use_vq']:
    def QClass():
        """Build a VQQuantizer configured from the CLI/config `args` dict."""
        return VQQuantizer(
            vq_dim=args['vq_dim'],
            columns_per_group=args['columns_per_group'],
            vq_scaling_blocksize=args['vq_scaling_blocksize'],
            vq_scaling_norm=args['vq_scaling_norm'],
            vq_scaling_n_bits=args['vq_scaling_n_bits'],
            vq_scaling_domain=args['vq_scaling_domain'],
            kmeans_init_method=args['kmeans_init_method'],
            assignment_chunk_size=args['assignment_chunk_size'],
            kmeans_iters=args['kmeans_iters'],
            codebook_bitwidth=args['codebook_bitwidth'],
            quantize_per_codebook=args['quantize_per_codebook'],
            quantize_during_kmeans=args['quantize_during_kmeans'],
            n_subsample=args['n_subsample'],
        )
else:
    QClass = Quantizer
# --- Layer-by-layer GPTQ quantization of the transformer blocks -------------
transformer_quantizers = {}
for layer_index in range(len(layers)):
    layer = layers[layer_index]
    full = find_layers(layer)

    # Quantization order: either grouped sub-layers (true sequential mode)
    # or all eligible modules at once, skipping the MoE gates.
    if args["true_sequential"]:
        sequential = [
            ["attn1.to_k", "attn1.to_v", "attn1.to_q"],
            ["attn1.to_out.0"],
            ["attn2.to_k", "attn2.to_v", "attn2.to_q"],
            ["attn2.to_out.0"],
            ['ff.net.0.proj'],
            ['ff.net.2'],
        ]
    else:
        sequential = [[k for k in list(full.keys()) if "block_sparse_moe.gate" not in k]]

    for names in sequential:
        subset = {n: full[n] for n in names}

        # One GPTQ state object (Hessian accumulator + quantizer) per module.
        gptq = {}
        for name in subset:
            handler = GPTQ(subset[name])
            handler.quantizer = QClass()
            handler.quantizer.configure(args["wbits"], perchannel=True, sym=args["sym"], mse=False)
            gptq[name] = handler

        def add_batch(name):
            # Hook factory: binds `name` as a parameter so each forward hook
            # feeds activations into its own GPTQ accumulator.
            def tmp(_, inp, out):
                gptq[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in subset]

        # Replay the calibration samples so the hooks collect statistics.
        for j in range(args["nsamples"]):
            extra = {k: v for k, v in inputs[j].items() if k != 'input'}
            outputs[j] = layer(inputs[j]['input'], **extra)
        for h in handles:
            h.remove()

        for name in subset:
            print(layer_index, name)
            print("Quantizing ...")
            gptq[name].fasterquant(
                percdamp=args["percdamp"],
                groupsize=args["groupsize"],
                actorder=args["act_order"],
                static_groups=args["static_groups"],
                include_m_step=args["include_m_step"],
                use_vq=args["use_vq"],
                svd_rank=args["svd_rank"],
                hessian_weighted_lookups=args["hessian_weighted_lookups"],
                only_init_kmeans=args["only_init_kmeans"],
            )
            transformer_quantizers["model.layers.%d.%s" % (layer_index, name)] = gptq[name].quantizer
            gptq[name].free()

    # Recompute this layer's outputs with its now-quantized weights so the
    # next layer calibrates against what it will actually see at inference.
    for j in range(args["nsamples"]):
        extra = {k: v for k, v in inputs[j].items() if k != 'input'}
        outputs[j] = layer(inputs[j]['input'], **extra)

    layers[layer_index] = layer.cpu()
    del layer
    del gptq
    torch.cuda.empty_cache()

    # Rotate: this layer's outputs become the next layer's inputs.
    for k, v in outputs.items():
        inputs[k]['input'] = v
OK