microsoft / nni

An open-source AutoML toolkit to automate the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyperparameter tuning.
https://nni.readthedocs.io
MIT License

“assert len(args) >= len(self.undetermined) AssertionError” when pruning and speeding up the diffusion model using LinearPruner #5524

Status: Open. DragonDRLI opened this issue 1 year ago.

DragonDRLI commented 1 year ago

Describe the bug: “assert len(args) >= len(self.undetermined) AssertionError” is raised when pruning and speeding up a diffusion model's UNet with LinearPruner.

Environment:

Reproduce the problem

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from nni.compression.pytorch.pruning import LinearPruner

# parse_args(), train_dataloader and noise_scheduler are defined elsewhere in text2image.py


def main():
    args = parse_args()

    # instantiate the text_encoder and the tokenizer
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder='text_encoder',
        revision=args.revision,
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder='tokenizer',
        revision=args.revision,
    )
    # instantiate VAE
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder='vae',
        revision=args.revision,
    )
    # instantiate unet
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder='unet',
        revision=args.revision,
    )
    # freeze the weights of the VAE and the text_encoder; only the UNet is pruned and fine-tuned
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    weight_dtype = torch.float32
    device = 'cuda:0'
    unet.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    vae_decoder = vae.decoder

    # Pruning configuration: prune every Conv2d to 20% sparsity, except the layers listed
    # below (attention proj_in, the first conv of each resnet, and the down/up-sampler convs)
    exclude_op_names_db = (
        ['down_blocks.{0}.attentions.{1}.proj_in'.format(i, j) for i in range(3) for j in range(2)]
        + ['down_blocks.{0}.resnets.{1}.conv1'.format(i, j) for i in range(4) for j in range(2)]
    )
    exclude_op_names_mb = ['mid_block.attentions.0.proj_in', 'mid_block.resnets.0.conv1', 'mid_block.resnets.1.conv1']
    exclude_op_names_ub = (
        ['up_blocks.{0}.attentions.{1}.proj_in'.format(i, j) for i in range(1, 4) for j in range(3)]
        + ['up_blocks.{0}.resnets.{1}.conv1'.format(i, j) for i in range(4) for j in range(3)]
    )
    exclude_op_names_du = (
        ['down_blocks.{}.downsamplers.0.conv'.format(i) for i in range(3)]
        + ['up_blocks.{}.upsamplers.0.conv'.format(i) for i in range(3)]
    )
    exclude_op_names = exclude_op_names_db + exclude_op_names_mb + exclude_op_names_ub + exclude_op_names_du

    config_list = [{'op_types': ['Conv2d'],
                    'sparsity_per_layer': 0.2},
                   {'exclude': True,
                    'op_names': exclude_op_names}]

    def finetuner(unet):
        """Fine-tune the UNet after it has been pruned and sped up."""
        unet.train()
        optimizer = torch.optim.AdamW(
            unet.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon)
        for batch in train_dataloader:
            # encode images into scaled latents (0.18215 is the SD v1 VAE scaling factor)
            latents = vae.encode(batch["pixel_values"].to(device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)  # + 0.3 * torch.randn(latents.shape[0], latents.shape[1], 1, 1).to(latents.device)
            bsz = latents.shape[0]
            # sample a random diffusion timestep for each image in the batch
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
            # the regression target depends on the scheduler's prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            optimizer.zero_grad()
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            loss.backward()
            optimizer.step()

    # dummy inputs used to trace the UNet for speedup:
    # (noisy latents, timestep, text-encoder hidden states)
    noisy_input = torch.rand(1, 4, 56, 88, device=device)
    timestep_input = torch.randint(0, 50, (1,), device=device)
    text_input = torch.rand(1, 77, 768, device=device)
    dummy_input = (noisy_input, timestep_input, text_input)

    kw_args = {'pruning_algorithm': 'fpgm',
               'total_iteration': 2,
               'evaluator': None,
               'finetuner': finetuner,
               'speedup': True,
               'dummy_input': dummy_input}

    pruner = LinearPruner(unet, config_list, **kw_args)
    pruner.compress()
    _, unet, masks, _, _ = pruner.get_best_result()


if __name__ == '__main__':
    main()
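Because the exclusion list is assembled by hand from several sub-lists, it is easy to concatenate one sub-list twice or leave one out, and a misspelled name can go unnoticed. A minimal pure-Python sketch (the helper names build_exclude_op_names and missing_names are hypothetical, not part of NNI) that rebuilds the same 44 layer names, drops duplicates, and reports any name that does not match a module name:

```python
def build_exclude_op_names():
    """Rebuild the exclusion list used in the repro script, without duplicates."""
    down = [f"down_blocks.{i}.attentions.{j}.proj_in" for i in range(3) for j in range(2)]
    down += [f"down_blocks.{i}.resnets.{j}.conv1" for i in range(4) for j in range(2)]
    mid = ["mid_block.attentions.0.proj_in", "mid_block.resnets.0.conv1", "mid_block.resnets.1.conv1"]
    up = [f"up_blocks.{i}.attentions.{j}.proj_in" for i in range(1, 4) for j in range(3)]
    up += [f"up_blocks.{i}.resnets.{j}.conv1" for i in range(4) for j in range(3)]
    samplers = [f"down_blocks.{i}.downsamplers.0.conv" for i in range(3)]
    samplers += [f"up_blocks.{i}.upsamplers.0.conv" for i in range(3)]
    # dict.fromkeys keeps the first occurrence of each name and preserves order
    return list(dict.fromkeys(down + mid + up + samplers))


def missing_names(op_names, module_names):
    """Return the op names that do not match any module name (likely typos)."""
    known = set(module_names)
    return [name for name in op_names if name not in known]
```

With the real model, the module names would come from PyTorch's named_modules(), e.g. missing_names(build_exclude_op_names(), [n for n, _ in unet.named_modules()]); an empty result means every excluded name points at an existing layer.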

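With 'total_iteration': 2 and 'sparsity_per_layer': 0.2, LinearPruner runs two rounds of prune, speed up (the ModelSpeedup call that fails in the traceback) and fine-tune. A sketch of the per-round targets, assuming the linear schedule raises sparsity in equal steps so that round k of N targets k / N of the final value (measured against the original model); check the NNI documentation for your version before relying on the exact formula:

```python
from typing import List


def linear_sparsity_schedule(target_sparsity: float, total_iteration: int) -> List[float]:
    """Per-round target sparsity rising linearly to target_sparsity.

    Assumption: round k (1-based) targets k / total_iteration * target_sparsity.
    """
    if total_iteration < 1:
        raise ValueError("total_iteration must be >= 1")
    return [target_sparsity * k / total_iteration for k in range(1, total_iteration + 1)]
```

Under this assumption, the settings in the script give targets of 0.1 after the first round and 0.2 after the second, and the model is sped up after each round.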

- How to reproduce:
Set args.pretrained_model_name_or_path to the path of the models (UNet, CLIPTextModel, CLIPTokenizer, VAE).
- Log message:
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.attn2.aten::reshape.671
[2023-04-20 20:25:47] WARNING: throw some args away when calling the function "reshape"
[2023-04-20 20:25:47] WARNING: throw some args away when calling the function "reshape"
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.1
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.aten::add.614
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.norm3
[2023-04-20 20:25:47] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj
[2023-04-20 20:25:48] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.aten::chunk.672
[2023-04-20 20:25:49] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.prim::ListUnpack.673
[2023-04-20 20:25:49] Update mask for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.aten::gelu.674
Traceback (most recent call last):
  File "text2image.py", line 259, in <module>
    main()
  File "text2image.py", line 218, in main
    pruner.compress()
  File "/opt/conda/lib/python3.8/site-packages/nni/algorithms/compression/v2/pytorch/base/scheduler.py", line 194, in compress
    task_result = self.pruning_one_step(task)
  File "/opt/conda/lib/python3.8/site-packages/nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py", line 283, in pruning_one_step
    result = self.pruning_one_step_normal(task)
  File "/opt/conda/lib/python3.8/site-packages/nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py", line 154, in pruning_one_step_normal
    ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
  File "/opt/conda/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 546, in speedup_model
    self.infer_modules_masks()
  File "/opt/conda/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 383, in infer_modules_masks
    self.update_direct_sparsity(curnode)
  File "/opt/conda/lib/python3.8/site-packages/nni/compression/pytorch/speedup/compressor.py", line 237, in update_direct_sparsity
    _auto_infer = AutoMaskInference(
  File "/opt/conda/lib/python3.8/site-packages/nni/compression/pytorch/speedup/infer_mask.py", line 80, in __init__
    self.output = self.module(*dummy_input)
  File "/opt/conda/lib/python3.8/site-packages/nni/compression/pytorch/speedup/jit_translate.py", line 227, in __call__
    assert len(args) >= len(self.undetermined)
AssertionError
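The assertion is raised in NNI's jit_translate.py, which wraps traced TorchScript ops so they can be re-run during mask inference. Below is a simplified, hypothetical model of such a wrapper (ToyAdapter is an illustrative name, not NNI's class), inferred only from the assertion text and the "throw some args away" warning in the log: constant arguments are captured when the graph is translated, and the "undetermined" slots are filled with the values passed at call time. If a node receives fewer runtime inputs than it has undetermined slots, the check fails exactly as in the traceback.

```python
import logging
from typing import Any, Callable, List

logger = logging.getLogger(__name__)


class ToyAdapter:
    """Hypothetical stand-in for a wrapper around one translated graph op."""

    def __init__(self, func: Callable, positional: List[Any], undetermined: List[int]):
        self.func = func
        self.positional = list(positional)  # constants captured at translation time
        self.undetermined = undetermined    # positional slots filled at call time

    def __call__(self, *args):
        # Same condition as the failing check: one runtime value per undetermined slot.
        assert len(args) >= len(self.undetermined)
        if len(args) > len(self.undetermined):
            logger.warning('throw some args away when calling the function "%s"', self.func.__name__)
        positional = list(self.positional)
        # zip stops at the shorter sequence, so surplus runtime args are dropped
        for slot, value in zip(self.undetermined, args):
            positional[slot] = value
        return self.func(*positional)
```

In the log, the last node updated before the failure is ff.net.0.aten::gelu.674, which is fed by aten::chunk and prim::ListUnpack; this suggests the gelu call received fewer inputs than the translator expected, though that is an inference from the log rather than a confirmed diagnosis.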
J-shang commented 1 year ago

hello @DragonDRLI, we are working on a fix for this issue and will let you know once it is done.

DragonDRLI commented 1 year ago

Thank you for your response. I really appreciate your efforts and support. I am very eager to solve this problem as soon as possible because it is very important to my work. Thank you again for your patience and assistance! @J-shang

DragonDRLI commented 1 year ago

Could you give me some suggestions on this issue? What might be causing this problem? @J-shang

J-shang commented 1 year ago

hello @DragonDRLI, we have released a new NNI version; please try it with pip install nni==3.0rc1. Here is an example of using the new LinearPruner: https://github.com/microsoft/nni/blob/master/examples/compression/pruning/scheduled_pruning.py

If speedup v2 still has issues, feel free to contact us.

Lijiaoa commented 1 year ago

Hi, @DragonDRLI, any updates for it?

DragonDRLI commented 1 year ago

hi @Lijiaoa @J-shang, I haven't tried your suggestion yet because I now have a different question: is my script correct? I wrote it based on another script that uses LinearPruner, which differs from the example you linked. Is the LinearPruner usage shown in both scripts correct?