[Open] watchstep opened this issue 11 months ago
I'm trying to prune DiT-XL-2-256 with Taylor pruning. While pruning the model, the following error occurs during `pruner.step()`.
```python
import torch
import torchvision
import torch.nn.utils.prune as prune
from torchvision import models
import torch.nn as nn
import torch_pruning as tp
import diffusers
from diffusers import LDMPipeline, DiffusionPipeline, DiTPipeline, DDIMPipeline, DDIMScheduler
import os
import random
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="./DiT-XL-2-256")
parser.add_argument("--save_path", type=str, default="./run/dit_256_pruned/")
parser.add_argument("--pruning_ratio", type=float, default=0.3)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--device", type=str, default='cuda')
parser.add_argument("--pruning_type", type=str, default='taylor', choices=['taylor', 'random', 'l1', 'l2'])
args = parser.parse_args([])

# Load the DiT pipeline and move the transformer / VAE to the target device
pipeline = DiffusionPipeline.from_pretrained(args.model_path)
scheduler = pipeline.scheduler
transformer = pipeline.transformer.eval()
vae = pipeline.vae.eval()

torch_device = torch.device(args.device) if torch.cuda.is_available() else "cpu"
transformer.to(torch_device)
vae.to(torch_device)

# Dummy inputs used to trace the transformer's dependency graph
example_inputs = {
    'hidden_states': torch.randn(1, transformer.in_channels, transformer.sample_size, transformer.sample_size).to(torch_device),
    'timestep': torch.LongTensor([1]).to(torch_device),
    'class_labels': torch.LongTensor([1, 1000]).to(torch_device),
}

# Select the importance criterion
if args.pruning_type == 'taylor':
    imp = tp.importance.TaylorImportance()
elif args.pruning_type == 'random':
    imp = tp.importance.RandomImportance()
elif args.pruning_type == 'l1':
    imp = tp.importance.MagnitudeImportance(p=1)
elif args.pruning_type == 'l2':
    imp = tp.importance.MagnitudeImportance(p=2)
else:
    raise NotImplementedError

# Keep the final output projection (out_features == 32) untouched
ignored_layers = []
for m in transformer.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 32:
        ignored_layers.append(m)
# unwrapped_parameters = [transformer.pos_embed]

# Record the number of heads for each attention q/k/v projection
from diffusers.models.attention import Attention
num_heads = {}
for m in transformer.modules():
    if isinstance(m, Attention):
        num_heads[m.to_q] = m.heads
        num_heads[m.to_k] = m.heads
        num_heads[m.to_v] = m.heads

pruner = tp.pruner.MagnitudePruner(
    transformer,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=args.pruning_ratio,
    ignored_layers=ignored_layers,
    # num_heads=num_heads,
    # output_transform=lambda out: out.logits.sum(),
    # prune_head_dims=True,
    # prune_num_heads=False,
)

base_macs, base_params = tp.utils.count_ops_and_params(transformer, example_inputs)
transformer.zero_grad()
transformer.eval()

for g in pruner.step(interactive=True):
    print(g)
    g.prune()
```
When I printed `g`, I received the following result: "Warning! No positional inputs found for a module, assuming batch size is 1."
The error occurred as follows.
Hi @watchstep, I have not tested TP with DiT. Could you uncomment the `num_heads=num_heads` line? This parameter is required for transformer pruning.
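For reference, a minimal sketch of that change, reusing the `num_heads` dict and the other variables already defined in your script (the head-pruning flags are the options you have commented out; they may need tuning for DiT):

```python
# Sketch only: pass the attention-head mapping so the to_q/to_k/to_v projections
# are pruned consistently across heads. All variables come from the script above.
pruner = tp.pruner.MagnitudePruner(
    transformer,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=args.pruning_ratio,
    ignored_layers=ignored_layers,
    num_heads=num_heads,      # attention q/k/v layer -> number of heads
    prune_head_dims=True,     # shrink the per-head dimension
    prune_num_heads=False,    # keep the original head count
)
```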