HoangHoang1408 / temp


pixart_alpha_transformer #27

Open HoangHoang1408 opened 4 days ago

HoangHoang1408 commented 4 days ago

oke

HoangHoang1408 commented 4 days ago

Quantize Transformer

# Grab the stack of transformer blocks; they will be quantized one block at a time.
transformer_model = to_quantize_modules['transformer']
layers = transformer_model.transformer_blocks
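For context, `to_quantize_modules` and `model_pipeline` are not defined anywhere in this thread. A minimal sketch of how they could be set up with diffusers' `PixArtAlphaPipeline` (the checkpoint id, device, and dtype here are assumptions):

```python
import torch
from diffusers import PixArtAlphaPipeline

# Hypothetical setup; any PixArt-alpha checkpoint with a Transformer2DModel works.
model_pipeline = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

to_quantize_modules = {
    "transformer": model_pipeline.transformer,  # exposes .transformer_blocks
}
```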

HoangHoang1408 commented 4 days ago

inputs = {}
outputs = {}
index = {"i": 0}

# Wrap the first transformer block so the inputs it receives during calibration
# are recorded; raising ValueError aborts the rest of the forward pass early.
class Catcher(nn.Module):
    def __init__(self, module, layer_index):
        super().__init__()
        self.module = module
        self.layer_index = layer_index

    def forward(self, inp, **kwargs):
        inputs[index['i']] = {
            "input": inp
        }
        for key, val in kwargs.items():
            inputs[index['i']][key] = val
        index['i'] += 1
        raise ValueError

layers[0] = Catcher(layers[0], 0)

# Run calibration samples through the pipeline; Catcher raises ValueError after
# recording each first-block input, so the exception is swallowed on purpose.
for d in data[:args['nsamples']]:
    try:
        model_pipeline(d, num_inference_steps=2)
    except Exception:
        pass
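`data` is also not shown in the thread; for a text-to-image pipeline it is just a list of calibration prompts. The prompts below are made-up placeholders:

```python
# Hypothetical calibration prompts; any representative set will do,
# as long as there are at least args['nsamples'] of them.
data = [
    "a photograph of an astronaut riding a horse",
    "a watercolor painting of a mountain lake at sunrise",
    "a close-up photo of a red fox in the snow",
    "an isometric illustration of a futuristic city at night",
]
```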

HoangHoang1408 commented 4 days ago

# Unwrap the Catcher: restore the original first block now that its inputs are captured.
layers[0] = layers[0].module

# Choose the quantizer factory: vector quantization (VQ) or the scalar quantizer.
if args['use_vq']:
    QClass = lambda: VQQuantizer(
        vq_dim=args['vq_dim'],
        columns_per_group=args['columns_per_group'],
        vq_scaling_blocksize=args['vq_scaling_blocksize'],
        vq_scaling_norm=args['vq_scaling_norm'],
        vq_scaling_n_bits=args['vq_scaling_n_bits'],
        vq_scaling_domain=args['vq_scaling_domain'],
        kmeans_init_method=args['kmeans_init_method'],
        assignment_chunk_size=args['assignment_chunk_size'],
        kmeans_iters=args['kmeans_iters'],
        codebook_bitwidth=args['codebook_bitwidth'],
        quantize_per_codebook=args['quantize_per_codebook'],
        quantize_during_kmeans=args['quantize_during_kmeans'],
        n_subsample=args['n_subsample'],
    )
else:
    QClass = Quantizer
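The loop in the next comment relies on a `find_layers` helper that is not shown in this thread. A sketch consistent with the usual GPTQ-style utility (recursively collecting the Linear/Conv leaves of a block) is:

```python
import torch.nn as nn

def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    # Recursively map "dotted.name" -> submodule for every quantizable leaf.
    if isinstance(module, layers):
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        res.update(find_layers(
            child, layers=layers,
            name=f"{name}.{child_name}" if name else child_name,
        ))
    return res
```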

HoangHoang1408 commented 4 days ago

transformer_quantizers = {}

for layer_index in range(len(layers)):
    layer = layers[layer_index]
    full = find_layers(layer)

    # Quantize the sub-layers of each block in dependency order (attention
    # projections, then their output projections, then the feed-forward),
    # or all at once when true_sequential is disabled.
    if args["true_sequential"]:
        sequential = [
            ["attn1.to_k", "attn1.to_v", "attn1.to_q"],
            ["attn1.to_out.0"],
            ["attn2.to_k", "attn2.to_v", "attn2.to_q"],
            ["attn2.to_out.0"],
            ["ff.net.0.proj"],
            ["ff.net.2"],
        ]
    else:
        sequential = [[k for k in list(full.keys()) if "block_sparse_moe.gate" not in k]]

    for names in sequential:
        subset = {n: full[n] for n in names}
        gptq = {}
        for name in subset:
            gptq[name] = GPTQ(subset[name])
            gptq[name].quantizer = QClass()
            gptq[name].quantizer.configure(args["wbits"], perchannel=True, sym=args["sym"], mse=False)

        # Forward hooks accumulate the Hessian statistics for each sub-layer.
        def add_batch(name):
            def tmp(_, inp, out):
                gptq[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in subset:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(args["nsamples"]):
            outputs[j] = layer(inputs[j]['input'], **{k: v for k, v in inputs[j].items() if k != 'input'})
        for h in handles:
            h.remove()

        for name in subset:
            print(layer_index, name)
            print("Quantizing ...")
            gptq[name].fasterquant(
                percdamp=args["percdamp"],
                groupsize=args["groupsize"],
                actorder=args["act_order"],
                static_groups=args["static_groups"],
                include_m_step=args["include_m_step"],
                use_vq=args["use_vq"],
                svd_rank=args["svd_rank"],
                hessian_weighted_lookups=args["hessian_weighted_lookups"],
                only_init_kmeans=args["only_init_kmeans"],
            )
            transformer_quantizers["model.layers.%d.%s" % (layer_index, name)] = gptq[name].quantizer
            gptq[name].free()

    # Re-run the block with its quantized weights so the next block sees quantized activations.
    for j in range(args["nsamples"]):
        outputs[j] = layer(inputs[j]['input'], **{k: v for k, v in inputs[j].items() if k != 'input'})

    layers[layer_index] = layer.cpu()
    del layer
    del gptq
    torch.cuda.empty_cache()

    # The outputs of this block become the inputs of the next one.
    for k, v in outputs.items():
        inputs[k]['input'] = v
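The thread stops at the loop. Assuming `fasterquant` writes the quantized weights back into each layer in place (as the reference GPTQ implementation does), a minimal follow-up would be to persist the results and sanity-check the pipeline; everything below is an assumption, not code from the issue:

```python
# `transformer_quantizers` maps "model.layers.<i>.<name>" to the fitted
# quantizer objects; save them alongside the (fake-)quantized weights.
torch.save(transformer_quantizers, "pixart_transformer_quantizers.pt")
transformer_model.save_pretrained("pixart_alpha_transformer_quantized")

# Blocks were moved to CPU during quantization; bring the transformer back
# before running a quick visual sanity check.
model_pipeline.transformer.to("cuda")
image = model_pipeline(
    "a photograph of an astronaut riding a horse", num_inference_steps=20
).images[0]
image.save("quantized_sample.png")
```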