VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

Pruning DiT #300

Open watchstep opened 11 months ago

watchstep commented 11 months ago

I'm trying to prune DiT-XL-2-256 with Taylor pruning. While pruning the model, the following error occurs during pruner.step(). Here is my script:

# only the imports actually used below
import argparse

import torch
import torch_pruning as tp
from diffusers import DiffusionPipeline

parser = argparse.ArgumentParser()

parser.add_argument("--model_path", type=str, default="./DiT-XL-2-256")
parser.add_argument("--save_path", type=str, default="./run/dit_256_pruned/")
parser.add_argument("--pruning_ratio", type=float, default=0.3)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--device", type=str, default='cuda')
parser.add_argument("--pruning_type", type=str, default='taylor', choices=['taylor', 'random', 'l1', 'l2'])

args = parser.parse_args([])  # parse defaults only (no CLI arguments)

pipeline = DiffusionPipeline.from_pretrained(args.model_path)
scheduler = pipeline.scheduler
transformer = pipeline.transformer.eval()
vae = pipeline.vae.eval()

torch_device = torch.device(args.device if torch.cuda.is_available() else "cpu")

transformer.to(torch_device)
vae.to(torch_device)

example_inputs = {
    'hidden_states': torch.randn(1, transformer.in_channels,
                                 transformer.sample_size, transformer.sample_size).to(torch_device),
    'timestep': torch.LongTensor([1]).to(torch_device),
    'class_labels': torch.LongTensor([1, 1000]).to(torch_device),
}

if args.pruning_type == 'taylor':
    # note: Taylor importance is gradient-based, so gradients should be
    # accumulated (e.g. via a backward pass) before calling pruner.step()
    imp = tp.importance.TaylorImportance()
elif args.pruning_type == 'random':
    imp = tp.importance.RandomImportance()
elif args.pruning_type == 'l1':
    imp = tp.importance.MagnitudeImportance(p=1)
elif args.pruning_type == 'l2':
    imp = tp.importance.MagnitudeImportance(p=2)
else:
    raise NotImplementedError

# keep the final output projection intact (out_features == 32 for DiT-XL/2-256)
ignored_layers = []
for m in transformer.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 32:
        ignored_layers.append(m)

# unwrapped_parameters = [transformer.pos_embed]
from diffusers.models.attention import Attention
# map each q/k/v projection to its number of attention heads
num_heads = {}
for m in transformer.modules():
    if isinstance(m, Attention):
        num_heads[m.to_q] = m.heads
        num_heads[m.to_k] = m.heads
        num_heads[m.to_v] = m.heads

pruner = tp.pruner.MagnitudePruner(
    transformer,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=args.pruning_ratio,
    ignored_layers=ignored_layers,
    # num_heads=num_heads,
    # output_transform=lambda out: out.logits.sum(),
    # prune_head_dims=True,
    # prune_num_heads=False, 
)

base_macs, base_params = tp.utils.count_ops_and_params(transformer, example_inputs)
transformer.zero_grad()
transformer.eval()

for g in pruner.step(interactive=True):
    print(g)
    g.prune()

When I printed g, I got the following warning: "Warning! No positional inputs found for a module, assuming batch size is 1."

The error then occurred as follows: [screenshot of the traceback]
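(Since the traceback only exists as a screenshot, one hedged diagnostic, not from the original report: run the example inputs through the unpruned transformer first to rule out input-shape problems. Note that class_labels above has two entries while hidden_states has batch size 1, whereas a DiT forward expects one label per sample. The check_inputs dict below is hypothetical.)

# hypothetical sanity check (not part of the issue): the example inputs should
# forward through the unpruned transformer before any tracing/pruning is attempted
with torch.no_grad():
    check_inputs = {
        'hidden_states': torch.randn(1, transformer.in_channels,
                                     transformer.sample_size, transformer.sample_size).to(torch_device),
        'timestep': torch.LongTensor([1]).to(torch_device),
        'class_labels': torch.LongTensor([1]).to(torch_device),  # one label per sample, in [0, 999]
    }
    out = transformer(**check_inputs)
    print(out.sample.shape if hasattr(out, 'sample') else type(out))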

VainF commented 11 months ago

Hi @watchstep, I have not tested TP with DiT. Could you try uncommenting the "num_heads=num_heads" line? This parameter is required for transformer pruning.
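For reference, a minimal sketch of that change, reusing the head-related options that are commented out in the snippet above (whether prune_head_dims/prune_num_heads should be set exactly this way for DiT is not confirmed in this thread):

# sketch of the suggested fix: pass the q/k/v -> head-count map so the pruner
# treats the attention projections as multi-head layers
pruner = tp.pruner.MagnitudePruner(
    transformer,
    example_inputs,
    importance=imp,
    iterative_steps=1,
    pruning_ratio=args.pruning_ratio,
    ignored_layers=ignored_layers,
    num_heads=num_heads,       # required for transformer pruning (per the reply above)
    prune_head_dims=True,      # taken from the commented-out options; adjust as needed
    prune_num_heads=False,
)

for g in pruner.step(interactive=True):
    g.prune()

# if prune_num_heads were enabled instead, the head count stored on each diffusers
# Attention module (m.heads) would also need to be updated after pruning,
# e.g. from pruner.num_heads[m.to_q]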