VainF / Diff-Pruning

[NeurIPS 2023] Structural Pruning for Diffusion Models
Apache License 2.0
159 stars 10 forks source link

Diff pruning and Training on LDM #6

Open FATE4869 opened 9 months ago

FATE4869 commented 9 months ago

Hi, thank you for publishing this amazing work on structural pruning of diffusion models. I was wondering whether you also plan to publish the code for Diff-Pruning and training on LDM. The current ldm-prune code only supports random, magnitude, and reinit pruning. Thanks!

VainF commented 9 months ago

Hi @FATE4869, we perform LDM Pruning & Finetuning based on the official repo. The pruning code we used:

python prune_ldm.py --sparsity 0.3 --pruner diff-pruning

The contents of prune_ldm.py:

import sys
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan 
import argparse
from ldm.modules.attention import CrossAttention

# Command-line options for pruning. Only pruners that have an importance
# criterion in this script are accepted; "reinit" used to be listed here but
# had no criterion, so it failed only after the full LDM had been loaded.
parser = argparse.ArgumentParser()
parser.add_argument("--sparsity", type=float, default=0.0,
                    help="fraction of channels to remove (torch_pruning ch_sparsity)")
parser.add_argument("--pruner", type=str,
                    choices=["magnitude", "random", "taylor", "diff-pruning", "diff0"],
                    default="magnitude",
                    help="importance criterion used to rank channels")
args = parser.parse_args()

#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

import torch_pruning as tp

def load_model_from_config(config, ckpt):
    """Instantiate ``config.model`` and load the weights stored in ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file containing a ``"state_dict"`` entry.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them so a mismatched
    # checkpoint does not silently leave parameters at their initialization.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"Missing keys ({len(missing)}): {missing[:10]}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}")
    model.cuda()
    model.eval()
    return model

def get_model(config_path="configs/latent-diffusion/cin256-v2.yaml",
              ckpt_path="models/ldm/cin256-v2/model.ckpt"):
    """Load the pretrained LDM described by ``config_path`` from ``ckpt_path``.

    Both paths default to the cin256-v2 files used by the original script, so
    existing ``get_model()`` calls are unaffected.

    Returns:
        The model returned by ``load_model_from_config`` (on CUDA, eval mode).
    """
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)

from ldm.models.diffusion.ddim import DDIMSampler

# Load the pretrained cin256-v2 LDM and wrap it in a DDIM sampler. The sampler
# is used both to generate calibration samples for the gradient-based pruners
# and to render the preview grid after pruning.
model = get_model()
sampler = DDIMSampler(model)

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid

# Class labels rendered in the post-pruning preview grid; classes[0] also
# conditions the dummy inputs used to trace the UNet for pruning.
classes = [25, 187, 448, 992]   # define classes to be sampled here
# Batch size for calibration sampling and for each preview row.
n_samples_per_class = 6

ddim_steps = 20   # passed as S (number of DDIM steps) to sampler.sample
ddim_eta = 0.0    # passed as eta to sampler.sample
scale = 3.0   # for unconditional guidance

# Dump the unpruned architecture for comparison with the post-pruning dump.
print(model)

print("Pruning ...")
model.eval()

# Map each --pruner choice to a torch_pruning importance criterion.
if args.pruner == "magnitude":
    imp = tp.importance.MagnitudeImportance()
elif args.pruner == "random":
    imp = tp.importance.RandomImportance()
elif args.pruner == 'taylor':
    # Standard first-order Taylor expansion.
    imp = tp.importance.TaylorImportance(multivariable=True)
elif args.pruner in ('diff-pruning', 'diff0'):
    # Modified Taylor criterion, estimating the accumulated error of weight removal.
    imp = tp.importance.TaylorImportance(multivariable=False)
else:
    raise ValueError(f"Unknown pruner '{args.pruner}'")

# The UNet's output layer is never pruned.
ignored_layers = [model.model.diffusion_model.out]
channel_groups = {}
iterative_steps = 1
# Conditioning for class label 1000; it is passed to the sampler as
# `unconditional_conditioning` during calibration.
uc = model.get_learned_conditioning(
    {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}
)

# Register each CrossAttention's q/k/v projections with its head count as
# channel groups (presumably so pruning stays consistent across heads).
for m in model.model.diffusion_model.modules():
    if isinstance(m, CrossAttention):
        channel_groups[m.to_q] = m.heads
        channel_groups[m.to_k] = m.heads
        channel_groups[m.to_v] = m.heads

# Dummy UNet inputs used to trace the dependency graph and to count MACs/params.
xc = torch.tensor(n_samples_per_class * [classes[0]])
c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
example_inputs = {
    "x": torch.randn(n_samples_per_class, 3, 64, 64).to(model.device),
    "timesteps": torch.full((n_samples_per_class,), 1, device=model.device, dtype=torch.long),
    "context": c,
}
base_macs, base_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
pruner = tp.pruner.MagnitudePruner(
    model.model.diffusion_model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    channel_groups=channel_groups,
    ch_sparsity=args.sparsity,  # target channel sparsity from --sparsity
    ignored_layers=ignored_layers,
    root_module_types=[torch.nn.Conv2d, torch.nn.Linear],
    round_to=2,
)
model.zero_grad()

import random

# Gradient-based criteria ('taylor', 'diff-pruning', 'diff0') need gradients
# accumulated over calibration batches before pruner.step(); magnitude and
# random importance do not, so they go straight to pruning.
if args.pruner in ('diff-pruning', 'taylor', 'diff0'):
    # Early-stop threshold on loss / max_loss. 'diff-pruning' stops once the
    # loss falls below 10% of the largest loss seen. 'diff0' uses 0.0, which
    # (assuming a non-negative loss) never triggers.
    thres = 0.1 if args.pruner == 'diff-pruning' else 0.0
    max_loss = -1
    for t in range(1000):
        # Random distinct class labels for this calibration batch.
        xc = torch.tensor(random.sample(range(1000), n_samples_per_class))
        c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                        conditioning=c,
                                        batch_size=n_samples_per_class,
                                        shape=[3, 64, 64],
                                        verbose=False,
                                        unconditional_guidance_scale=scale,
                                        unconditional_conditioning=uc,
                                        eta=ddim_eta)

        # NOTE(review): samples_ddim come from the sampler with shape [3, 64, 64]
        # per sample and are passed through encode_first_stage again; confirm
        # this re-encoding is intended.
        encoded = model.encode_first_stage(samples_ddim)
        # Kept under its own name: overwriting `example_inputs` here would make
        # the post-pruning MAC count use different inputs than base_macs.
        calib_inputs = {
            "x": encoded.to(model.device),
            "timesteps": torch.full((n_samples_per_class,), t, device=model.device, dtype=torch.long),
            "context": c,
        }
        # NOTE(review): get_loss_at_t is defined outside this file; this assumes
        # element [0] of its result is the scalar loss to backpropagate.
        loss = model.get_loss_at_t(calib_inputs['x'],
                                   {model.cond_stage_key: xc.to(model.device)},
                                   calib_inputs['timesteps'])
        loss = loss[0]
        if loss > max_loss:
            # Detach so the running maximum does not hold on to this step's graph.
            max_loss = loss.detach()
        if args.pruner in ('diff-pruning', 'diff0') and loss / max_loss < thres:
            break
        print(t, (loss / max_loss).item(), loss.item(), max_loss.item())
        # Gradients accumulate across timesteps (no zero_grad inside the loop).
        loss.backward()
pruner.step()

print("After pruning")
print(model)

# Re-count on the same dummy inputs used for base_macs / base_params and
# report both absolute values and the percentage that remains.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model.model.diffusion_model, example_inputs)
macs_kept_pct = pruned_macs / base_macs * 100
params_kept_pct = pruned_params / base_params * 100
print(f"MACs: {macs_kept_pct:.2f}%, {base_macs / 1e9:.2f}G => {pruned_macs / 1e9:.2f}G")
print(f"Params: {params_kept_pct:.2f}%, {base_params / 1e6:.2f}M => {pruned_params / 1e6:.2f}M")

# Render n_samples_per_class images for each label in `classes` with the pruned
# model, inside model.ema_scope() and without autograd.
all_samples = []

with torch.no_grad(), model.ema_scope():
    # Label 1000 is the conditioning passed as `unconditional_conditioning`.
    uc = model.get_learned_conditioning(
        {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}
    )

    for class_label in classes:
        print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
        label_batch = torch.tensor(n_samples_per_class * [class_label]).to(model.device)
        cond = model.get_learned_conditioning({model.cond_stage_key: label_batch})

        latents, _ = sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=n_samples_per_class,
            shape=[3, 64, 64],
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
        )

        # Decode, rescale with (x + 1) / 2 and clamp to [0, 1].
        images = model.decode_first_stage(latents)
        all_samples.append(torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0))

# Flatten the per-class batches into one batch and tile it so that each row
# holds one class (nrow == images per class).
grid = make_grid(
    rearrange(torch.stack(all_samples, 0), 'n b c h w -> (n b) c h w'),
    nrow=n_samples_per_class,
)

# Convert the CHW grid in [0, 1] to an HWC uint8 image and write it to disk.
pixels = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(pixels.astype(np.uint8)).save("samples.png")

import os

print("Saving pruned model ...")
# torch.save does not create parent directories; make sure logs/ exists so a
# long pruning run is not lost to a FileNotFoundError at the very end.
os.makedirs("logs", exist_ok=True)
# The whole module is pickled (not just a state_dict), presumably because the
# pruned layer shapes no longer match the YAML config.
torch.save(model, "logs/pruned_model_{}_{}.pt".format(args.sparsity, args.pruner))

The LDM project is a bit complicated, so we have not included LDM pruning and fine-tuning in this repo. The pruning and fine-tuning code is also attached as code.zip.

VainF commented 9 months ago

And the code for sampling:

python sample_for_FID.py --output run/samples --batch_size 10
import sys, os
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan 
import argparse
parser = argparse.ArgumentParser()
# (flag, type, default) for every command-line option.
_CLI_OPTIONS = (
    ("--pruned_model", str, None),    # pickled pruned model; None -> pretrained LDM
    ("--finetuned_ckpt", str, None),  # checkpoint whose "state_dict" is loaded on top
    ("--ipc", int, 50),               # images per class (split into --batch_size batches)
    ("--output", str, 'run'),         # directory receiving the PNG files
    ("--batch_size", int, 50),        # images per sampler.sample call
)
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()

#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

import torch_pruning as tp

from ldm.models.diffusion.ddim import DDIMSampler

def load_model_from_config(config, ckpt):
    """Instantiate ``config.model`` and load the weights stored in ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file containing a ``"state_dict"`` entry.

    Returns:
        The model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them so a mismatched
    # checkpoint does not silently leave parameters at their initialization.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"Missing keys ({len(missing)}): {missing[:10]}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}")
    model.cuda()
    model.eval()
    return model

def get_model(config_path="configs/latent-diffusion/cin256-v2.yaml",
              ckpt_path="models/ldm/cin256-v2/model.ckpt"):
    """Load the pretrained LDM described by ``config_path`` from ``ckpt_path``.

    Both paths default to the cin256-v2 files used by the original script, so
    existing ``get_model()`` calls are unaffected.

    Returns:
        The model returned by ``load_model_from_config`` (on CUDA, eval mode).
    """
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)

# Build the model to sample from: either the pretrained LDM, or a pickled
# pruned model, optionally overlaid with fine-tuned weights.
if args.pruned_model is None:
    model = get_model()
else:
    print("Loading model from ", args.pruned_model)
    # The pruned model is a whole pickled nn.Module, so torch.load must
    # unpickle it; only load files you trust.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled modules; confirm on your version.
    model = torch.load(args.pruned_model, map_location="cpu")
    if args.finetuned_ckpt is None:
        # --finetuned_ckpt is optional: without it, sample the pruned weights as saved.
        print("No --finetuned_ckpt given; sampling from the pruned model without fine-tuned weights.")
    else:
        print("Loading finetuned parameters from ", args.finetuned_ckpt)
        pl_sd = torch.load(args.finetuned_ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        # Report key mismatches instead of silently ignoring them (strict=False).
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing:
            print(f"Missing keys ({len(missing)}): {missing[:10]}")
        if unexpected:
            print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}")
model.cuda()
print(model)
sampler = DDIMSampler(model)

num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params / 1e6:.2f} M")

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid

classes = range(1000)   # define classes to be sampled here
n_samples_per_class = args.batch_size
if n_samples_per_class <= 0:
    raise ValueError(f"--batch_size must be positive, got {n_samples_per_class}")
# Each class is rendered in batches of --batch_size, so --ipc (images per
# class) is rounded down to a multiple of the batch size.
n_batch_per_class = args.ipc // n_samples_per_class
if n_batch_per_class <= 0:
    raise ValueError(
        f"--ipc ({args.ipc}) must be at least --batch_size ({n_samples_per_class}); "
        "otherwise no images would be generated"
    )
if args.ipc % n_samples_per_class:
    print(
        f"Warning: --ipc {args.ipc} is not a multiple of --batch_size {n_samples_per_class}; "
        f"generating {n_batch_per_class * n_samples_per_class} images per class."
    )

ddim_steps = 250
ddim_eta = 0.0
scale = 3.0   # for unconditional guidance

from torchvision import utils as tvu
os.makedirs(args.output, exist_ok=True)

# Run n_batch_per_class rounds over every class, writing each generated image
# to args.output as "<class_label>_<running index>.png".
img_id = 0
with torch.no_grad(), model.ema_scope():
    # Label 1000 is the conditioning passed as `unconditional_conditioning`.
    uc = model.get_learned_conditioning(
        {model.cond_stage_key: torch.tensor(n_samples_per_class * [1000]).to(model.device)}
    )

    for _round in range(n_batch_per_class):
        for class_label in classes:
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
            label_batch = torch.tensor(n_samples_per_class * [class_label]).to(model.device)
            cond = model.get_learned_conditioning({model.cond_stage_key: label_batch})

            latents, _ = sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=n_samples_per_class,
                shape=[3, 64, 64],
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
            )

            # Decode, rescale with (x + 1) / 2 and clamp to [0, 1].
            decoded = model.decode_first_stage(latents)
            images = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
            for image in images:
                tvu.save_image(
                    image, os.path.join(args.output, f"{class_label}_{img_id}.png")
                )
                img_id += 1
jonathanyang0227 commented 9 months ago

Hi @VainF, thank you for providing this work. I got the error below when running the code you provided above. Could you help me check it? Thanks a lot!

[Screenshot 2024-01-09 at 12:25:50 PM — image not included]
VainF commented 9 months ago

I guess the LDM repo has changed a lot in the past few months. Let me check and upload the whole LDM project.

jonathanyang0227 commented 9 months ago

I guess the LDM repo has changed a lot in the past few months. Let me check and upload the whole LDM project.

Thank you!!!

VainF commented 8 months ago

Hi @jonathanyang0227, I uploaded the original code for LDM. It's a bit messy.

https://github.com/VainF/Diff-Pruning/tree/main/ldm_exp

Sample images for FID:

# to generate the fid_stats_imagenet.npz file
 python fid_score.py --save-stats ~/Datasets/imagenet/train run/fid_stats_imagenet --device cuda:0 --batch-size 64 --num_samples 50000 --res 256

# sample images from the pruned LDM
python sample_for_FID.py --pruned_model logs/pruned_model_0.3_diff-pruning.pt --finetuned_ckpt logs/2023-08-06T01-06-01_cin256-v2/checkpoints/epoch=000004.ckpt --ipc 50 --output PATH_TO_YOUR_IMAGES

# FID
python fid_score.py run/fid_stats_imagenet.npz PATH_TO_YOUR_IMAGES  --device cuda:0 --batch-size 100 
jonathanyang0227 commented 8 months ago

thank you so much!!