mingyuanzhou / SiD-LSG

Score identity Distillation with Long and Short Guidance for One-Step Text-to-Image Generation
Apache License 2.0

some color spots appeared on the face #3

Open koking0 opened 3 days ago

koking0 commented 3 days ago

Has anyone encountered the following problem? I used SiD-LSG to distill an SDXL model (with some code adaptations for the text encoders), and some color spots appeared on the faces; they are very obvious when zoomed in.

[two attached images showing the color spots on generated faces]

mingyuanzhou commented 3 days ago

Could you please provide more details, such as the hardware used for distillation, whether fp16 or fp32 was used, and how many fake images were iterated to train the generator?

koking0 commented 2 days ago

Of course:

$ wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
$ python3.9 collect_env.py
Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (GCC) 8.2.0
Clang version: 3.8.0 (tags/RELEASE_380/final)
CMake version: version 3.28.1
Libc version: glibc-2.27

Python version: 3.9.13 (main, May 23 2022, 22:02:02)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.10.0-1.0.0.28-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A800-SXM4-80GB
GPU 1: NVIDIA A800-SXM4-80GB
GPU 2: NVIDIA A800-SXM4-80GB
GPU 3: NVIDIA A800-SXM4-80GB
GPU 4: NVIDIA A800-SXM4-80GB
GPU 5: NVIDIA A800-SXM4-80GB
GPU 6: NVIDIA A800-SXM4-80GB
GPU 7: NVIDIA A800-SXM4-80GB

Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.4.1
/usr/lib/libcudnn_adv_infer.so.8.4.1
/usr/lib/libcudnn_adv_train.so.8.4.1
/usr/lib/libcudnn_cnn_infer.so.8.4.1
/usr/lib/libcudnn_cnn_train.so.8.4.1
/usr/lib/libcudnn_ops_infer.so.8.4.1
/usr/lib/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              160
On-line CPU(s) list: 0-159
Thread(s) per core:  2
Core(s) per socket:  20
Socket(s):           4
NUMA node(s):        4
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Stepping:            7
CPU MHz:             3200.000
CPU max MHz:         3900.0000
CPU min MHz:         1000.0000
BogoMIPS:            5000.00
Virtualization:      VT-x
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            28160K
NUMA node0 CPU(s):   0-19,80-99
NUMA node1 CPU(s):   20-39,100-119
NUMA node2 CPU(s):   40-59,120-139
NUMA node3 CPU(s):   60-79,140-159
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] open-clip-torch==2.24.0
[pip3] torch==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] Could not collect

I mainly distilled the following model: https://civitai.com/models/112902/dreamshaper-xl

The main code changes are as follows.

  1. run_sid.sh: modified sd_model
    torchrun --standalone --nproc_per_node=8 sid_train.py \
    --outdir 'image_experiment/sid-lsg-train-runs/' \
    --train_mode 1 \
    --cfg_train_fake 2 \
    --cfg_eval_fake 2 \
    --cfg_eval_real 2 \
    --optimizer 'adam' \
    --data_prompt_text '/root/paddlejob/workspace/SiD-LSG-main/datasets' \
    --resolution 1024 \
    --alpha 1 \
    --init_timestep 625 \
    --batch 512 \
    --fp16 1 \
    --batch-gpu 1 \
    --sd_model "/root/paddlejob/workspace/models/DreamShaper_xl_v2_1" \
    --tick 2 \
    --snap 50 \
    --dump 100 \
    --lr 0.000001 \
    --glr 0.000001 \
    --duration 10 \
    --enable_xformers 1 \
    --gradient_checkpointing 1 \
    --ema 0
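    # With --batch 512, --batch-gpu 1, and 8 GPUs, batch_gpu_total = 512 / 8 = 64 samples per GPU per update,
    # i.e. 64 gradient-accumulation rounds of 1 sample each.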
  2. sid_train.py: due to insufficient memory, Adam8bit was used
    if opts.optimizer=='adam':
        c.fake_score_optimizer_kwargs = dnnlib.EasyDict(class_name='bitsandbytes.optim.Adam8bit', lr=opts.lr, betas=[0.0, 0.999], eps = 1e-8 if not opts.fp16 else 1e-6)
        c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='bitsandbytes.optim.Adam8bit', lr=opts.glr, betas=[0.0, 0.999], eps = 1e-8 if not opts.fp16 else 1e-6)
    else:
        #this is another optimizer to choose; it could provide better performance, but we have not carefully tested it yet
        assert opts.optimizer=='adamw'
        c.fake_score_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.AdamW', lr=opts.lr, betas=[0.0, 0.999], eps = 1e-8 if not opts.fp16 else 1e-6,weight_decay=0.01)
        c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.AdamW', lr=opts.glr, betas=[0.0, 0.999], eps = 1e-8 if not opts.fp16 else 1e-6,weight_decay=0.01)

    c.init_timestep = opts.init_timestep
  3. sid_training_loop.py: adapted to SDXL's text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2
    #Distill Stable Diffusion with SiD-LSG
    if train_mode:
        #Use a barrier if the weights need to be downloaded from the internet and saved to the cache
        if dist.get_rank() != 0:
            torch.distributed.barrier() 
        if dtype==torch.float16:
            #if the fp16 checkpoint variant is not available, you can load the fp32 version and then convert them into fp16
            unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2 = load_sd15(pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_vae_model_name_or_path=None,
                                                                    device=device, weight_dtype=dtype, variant="fp16", enable_xformers=enable_xformers, lora_config=lora_config)
        else:
            unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2 = load_sd15(pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_vae_model_name_or_path=None,
                                                                    device=device, weight_dtype=dtype, enable_xformers=enable_xformers, lora_config=lora_config)

        if dist.get_rank() == 0:
            torch.distributed.barrier()    
        dist.print0('Loading network completed')
        dist.print0('Noise scheduler:')
        dist.print0(noise_scheduler)

        # Initialize.
        start_time = time.time()
        np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
        torch.manual_seed(np.random.randint(1 << 31))
        torch.backends.cudnn.benchmark = cudnn_benchmark
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

        # Select batch size per GPU. Used for gradient accumulation
        batch_gpu_total = batch_size // dist.get_world_size()
        if batch_gpu is None or batch_gpu > batch_gpu_total:
            batch_gpu = batch_gpu_total
        num_accumulation_rounds = batch_gpu_total // batch_gpu
        assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()

        # Parameters for latent diffusion 
        latent_img_channels = 4
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        latent_resolution = resolution // vae_scale_factor

        #Prepare the random noise used for example image generation during training
        if dist.get_rank() == 0:
            original_seed = torch.initial_seed()

            # Set a temporary random seed
            temporary_seed = 2024
            torch.manual_seed(temporary_seed)
            grid_size, images, contexts = setup_snapshot_image_grid(training_set=dataset_obj)
            #contexts = [""] * len(contexts)
            grid_z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
            grid_z = grid_z.split(batch_gpu)
            grid_c = split_list(contexts, batch_gpu)
            # Revert back to the original random seed
            torch.manual_seed(original_seed)

        dataset_prompt_text_obj = dnnlib.util.construct_class_by_name(**dataset_prompt_text_kwargs) # subclass of training.dataset.Dataset
        dataset_prompt_text_sampler = misc.InfiniteSampler(dataset=dataset_prompt_text_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
        dataset_prompt_text_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_prompt_text_obj, sampler=dataset_prompt_text_sampler, batch_size=batch_gpu, **data_loader_kwargs))

        dist.print0("Example text prompts used for distillation:")
        for _i in range(16):
            dist.print0(_i)
            _,contexts = next(dataset_prompt_text_iterator)
            dist.print0(contexts)

        #Initialize the true score net, fake score net, and generator
        true_score = unet
        true_score.eval().requires_grad_(False).to(device)
        fake_score = copy.deepcopy(true_score).train().requires_grad_(True).to(device)
        G = copy.deepcopy(true_score).train().requires_grad_(True).to(device)

        # Setup optimizer.
        dist.print0('Setting up optimizer...')
        fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score.parameters(), **fake_score_optimizer_kwargs) # subclass of torch.optim.Optimizer
        g_optimizer = dnnlib.util.construct_class_by_name(params=G.parameters(), **g_optimizer_kwargs) # subclass of torch.optim.Optimizer

        # Resume training from previous snapshot.
        if resume_training is not None:
            dist.print0('checkpoint path:',resume_training)
            data = torch.load(resume_training, map_location=torch.device('cpu'))
            misc.copy_params_and_buffers(src_module=data['fake_score'], dst_module=fake_score, require_all=True)
            misc.copy_params_and_buffers(src_module=data['G'], dst_module=G, require_all=True)
            if ema_halflife_kimg>0:
                G_ema = copy.deepcopy(G).eval().requires_grad_(False)
                misc.copy_params_and_buffers(src_module=data['G_ema'], dst_module=G_ema, require_all=True)
                G_ema.eval().requires_grad_(False)
            else:
                G_ema=G
            fake_score_optimizer.load_state_dict(data['fake_score_optimizer_state'])
            g_optimizer.load_state_dict(data['g_optimizer_state'])
            del data # conserve memory
            dist.print0('Loading checkpoint completed')

            torch.distributed.barrier() 

            # Setup GPU parallel computing.
            dist.print0('Setting up GPU parallel computing')
            fake_score_ddp = torch.nn.parallel.DistributedDataParallel(fake_score, device_ids=[device], broadcast_buffers=False,find_unused_parameters=False)
            G_ddp = torch.nn.parallel.DistributedDataParallel(G, device_ids=[device], broadcast_buffers=False,find_unused_parameters=False)

        else:     
            # Setup GPU parallel computing.
            dist.print0('Setting up GPU parallel computing')
            fake_score_ddp = torch.nn.parallel.DistributedDataParallel(fake_score, device_ids=[device], broadcast_buffers=False,find_unused_parameters=False)
            G_ddp = torch.nn.parallel.DistributedDataParallel(G, device_ids=[device], broadcast_buffers=False,find_unused_parameters=False)
            if ema_halflife_kimg>0:
                G_ema = copy.deepcopy(G).eval().requires_grad_(False)
            else:
                G_ema = G

        fake_score_ddp.eval().requires_grad_(False)        
        G_ddp.eval().requires_grad_(False)

        # Train.
        dist.print0(f'Training for {total_kimg} kimg...')
        dist.print0()
        cur_nimg = resume_kimg * 1000
        cur_tick = 0
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - start_time
        dist.update_progress(cur_nimg // 1000, total_kimg)
        stats_jsonl = None
        stats_metrics = dict()

        if resume_training is None:
            if dist.get_rank() == 0:
                print('Exporting sample real images...')
                save_image_grid(img=images, fname=os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)

                print('Text prompts for example images:')
                for c in grid_c:
                    dist.print0(c)

                print('Exporting sample fake images at initialization...')
                images = [sid_sd_sampler(unet=G_ema,latents=z,contexts=c,init_timesteps = init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                             noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                             resolution=resolution,dtype=dtype,return_images=True, vae=vae,num_steps=num_steps,train_sampler=False,num_steps_eval=1) for z, c in zip(grid_z, grid_c)]
                images = torch.cat(images).cpu().numpy()
                save_image_grid(img=images, fname=os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
                del images

        torch.distributed.barrier() 

        dist.print0('Start Running')
        while True:
            torch.cuda.empty_cache()
            gc.collect()
            G_ddp.eval().requires_grad_(False)
            #----------------------------------------------------------------------------------------------
            # Update Fake Score Network
            fake_score_ddp.train().requires_grad_(True)
            fake_score_optimizer.zero_grad(set_to_none=True)            
            for round_idx in range(num_accumulation_rounds):
                _, contexts = next(dataset_prompt_text_iterator)
                if use_context_dropout_train_fake:
                    bool_tensor = torch.rand(len(contexts)) < 0.1
                    # Update contexts based on bool_tensor
                    contexts = ["" if flag else context for flag, context in zip(bool_tensor.tolist(), contexts)]
                    #print(contexts)
                z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
                noise = torch.randn_like(z)

                # Initialize timesteps
                init_timesteps = init_timestep * torch.ones((len(contexts),), device=device, dtype=torch.long)

                # Generate fake images (stop generator gradient)
                with misc.ddp_sync(G_ddp, False):
                    with torch.no_grad():
                        images = sid_sd_sampler(unet=G_ddp,latents=z,contexts=contexts,init_timesteps=init_timesteps,
                                             noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                             resolution=resolution,dtype=dtype,return_images=False, vae=None,num_steps=num_steps)

                timesteps = torch.randint(tmin, tmax, (len(contexts),), device=device, dtype=torch.long)

                # Compute loss for fake score network
                with misc.ddp_sync(fake_score_ddp, (round_idx == num_accumulation_rounds - 1)):
                    #Denoised fake images (stop generator gradient) under fake score network, using guidance scale: kappa1=cfg_train_fake
                    noise_fake = sid_sd_denoise(unet=fake_score_ddp,images=images,noise=noise,contexts=contexts,timesteps=timesteps,
                                                     noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                                     resolution=resolution,dtype=dtype,predict_x0=False,guidance_scale=cfg_train_fake)

                    nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)
                    if noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(images, noise, timesteps)
                        nan_mask = nan_mask | torch.isnan(target).flatten(start_dim=1).any(dim=1)

                    # Check if there are any NaN values present
                    if nan_mask.any():
                        # Invert the nan_mask to get a mask of samples without NaNs
                        non_nan_mask = ~nan_mask
                        # Filter out samples with NaNs from y_real and y_fake
                        noise_fake = noise_fake[non_nan_mask]
                        noise = noise[non_nan_mask]
                        if noise_scheduler.config.prediction_type == "v_prediction":
                            target = target[non_nan_mask]

                    if noise_scheduler.config.prediction_type == "v_prediction":
                        loss = (noise_fake-target)**2
                        snr = compute_snr(noise_scheduler, timesteps)
                        loss = loss * snr/(snr+1)
                    else:
                        loss = (noise_fake-noise)**2

                    loss=loss.sum().mul(loss_scaling / batch_gpu_total)

                    del images

                    if len(noise) > 0:
                        loss.backward()

            loss_fake_score_print = loss.item()
            training_stats.report('fake_score_Loss/loss', loss_fake_score_print)

            fake_score_ddp.eval().requires_grad_(False)

            # Update fake score network
            for param in fake_score.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

            fake_score_optimizer.step()

            #----------------------------------------------------------------------------------------------
            # Update One-Step Generator Network

            G_ddp.train().requires_grad_(True)
            g_optimizer.zero_grad(set_to_none=True)

            for round_idx in range(num_accumulation_rounds):
                _, contexts = next(dataset_prompt_text_iterator)
                if use_context_dropout_train_G:
                    bool_tensor = torch.rand(len(contexts)) < 0.1
                    # Update contexts based on bool_tensor
                    contexts = ["" if flag else context for flag, context in zip(bool_tensor.tolist(), contexts)]

                z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
                noise = torch.randn_like(z)

                # initialize timesteps
                init_timesteps = init_timestep * torch.ones((len(contexts),), device=device, dtype=torch.long)
                timesteps = torch.randint(tmin, tmax, (len(contexts),), device=device, dtype=torch.long)

                # Generate fake images (track generator gradient)
                with misc.ddp_sync(G_ddp, (round_idx == num_accumulation_rounds - 1)):
                    images = sid_sd_sampler(unet=G_ddp,latents=z,contexts=contexts,init_timesteps=init_timesteps,
                                         noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                         resolution=resolution,dtype=dtype,return_images=False,num_steps=num_steps)

                # Compute loss for generator    
                with misc.ddp_sync(fake_score_ddp, False): 
                    #Denoised fake images (track generator gradient) under fake score network, using guidance scale: kappa2=kappa3=cfg_eval_fake
                    y_fake = sid_sd_denoise(unet=fake_score_ddp,images=images,noise=noise,contexts=contexts,timesteps=timesteps,
                                             noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                             resolution=resolution,dtype=dtype,guidance_scale=cfg_eval_fake)

                    #Denoised fake images (track generator gradient) under pretrained score network, using guidance scale: kappa4=cfg_eval_real  
                    y_real = sid_sd_denoise(unet=true_score,images=images,noise=noise,contexts=contexts,timesteps=timesteps,
                                     noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                     resolution=resolution,dtype=dtype,guidance_scale=cfg_eval_real)

                    nan_mask_images = torch.isnan(images).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
                    nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake

                    # Check if there are any NaN values present
                    if nan_mask.any():
                        # Invert the nan_mask to get a mask of samples without NaNs
                        non_nan_mask = ~nan_mask
                        # Filter out samples with NaNs from y_real and y_fake
                        images = images[non_nan_mask]
                        y_real = y_real[non_nan_mask]
                        y_fake = y_fake[non_nan_mask]

                    with torch.no_grad():
                        weight_factor = abs(images.to(dtype) - y_real.to(dtype)).mean(dim=[1, 2, 3], keepdim=True).clip(min=0.00001)

                    if alpha==1:
                        loss = (y_real - y_fake) * (y_fake - images) / weight_factor
                    else:
                        loss = (y_real - y_fake) * ((y_real - images) - alpha * (y_real - y_fake)) / weight_factor

                    loss=loss.sum().mul(loss_scaling_G / batch_gpu_total)

                    if len(y_real) > 0:
                        loss.backward()

            lossG_print = loss.item()
            training_stats.report('G_Loss/loss', lossG_print)

            G_ddp.eval().requires_grad_(False)

            # Update generator
            for param in G.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

            #Apply gradient clipping under fp16 to prevent sudden divergence
            if dtype == torch.float16 and len(y_real) > 0:
                torch.nn.utils.clip_grad_value_(G.parameters(), 1) 

            g_optimizer.step()

            if ema_halflife_kimg>0:
                # Update EMA.
                ema_halflife_nimg = ema_halflife_kimg * 1000
                if ema_rampup_ratio is not None:
                    ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
                ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))

                for p_ema, p_true_score in zip(G_ema.parameters(), G.parameters()):
                    #p_ema.copy_(p_true_score.detach().lerp(p_ema, ema_beta))
                    with torch.no_grad():  
                        p_ema.copy_(p_true_score.detach().lerp(p_ema, ema_beta))
            else:
                G_ema=G

            cur_nimg += batch_size
            done = (cur_nimg >= total_kimg * 1000)

            if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
                continue

            # Print status line, accumulating the same information in training_stats.
            tick_end_time = time.time()
            fields = []
            fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
            fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
            fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
            fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
            fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
            fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
            fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
            fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
            fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
            fields += [f"loss_fake_score {training_stats.report0('fake_score_Loss/loss', loss_fake_score_print):<6.2f}"]
            fields += [f"loss_G {training_stats.report0('G_Loss/loss', lossG_print):<6.2f}"]
            torch.cuda.reset_peak_memory_stats()
            dist.print0(' '.join(fields))

            # Check for abort.
            if (not done) and dist.should_stop():
                done = True
                dist.print0()
                dist.print0('Aborting...')

            if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0 or cur_tick in [2,4,10,20,30,40,50,60,70,80,90,100]):

                dist.print0('Exporting sample images...')
                if dist.get_rank() == 0:
                    for num_steps_eval in [1,2,4]:
                        #While the generator is primarily trained to generate images in a single step, it can also be utilized in a multi-step setting during evaluation.
                        #To do: Distill a multi-step generator that is optimized for multi-step settings
                        images = [sid_sd_sampler(unet=G_ema,latents=z,contexts=c,init_timesteps=init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                                 noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                                 resolution=resolution,dtype=dtype,return_images=True, vae=vae,num_steps=num_steps,train_sampler=False,num_steps_eval=num_steps_eval) for z, c in zip(grid_z, grid_c)]
                        images = torch.cat(images).cpu().numpy()

                        # if cur_tick==0:
                        #     dist.print0(contexts[0])
                        #     #dist.print0(images[0])

                        save_image_grid(img=images, fname=os.path.join(run_dir, f'fakes_{alpha:03f}_{cur_nimg//1000:06d}_{num_steps_eval:d}.png'), drange=[-1,1], grid_size=grid_size)

                    del images

                if cur_tick>0:    
                    dist.print0('Evaluating metrics...')
                    dist.print0(metric_pt_path)     

                    if metrics is not None:    
                        for metric in metrics:

                            result_dict = metric_main.calc_metric(metric=metric, metric_pt_path=metric_pt_path, metric_open_clip_path=metric_open_clip_path, metric_clip_path=metric_clip_path,
                                G=partial(sid_sd_sampler,unet=G_ema,noise_scheduler=noise_scheduler,
                                             text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, 
                                             tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, 
                                                         resolution=resolution,dtype=dtype,return_images=True,vae=vae,num_steps=num_steps,train_sampler=False,num_steps_eval=1),
                                init_timestep=init_timestep,
                                dataset_kwargs=dataset_kwargs, num_gpus=dist.get_world_size(), rank=dist.get_rank(), local_rank=dist.get_local_rank(), device=device)
                            if dist.get_rank() == 0:
                                print(result_dict.results)
                                metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=f'fakes_{alpha:03f}_{cur_nimg//1000:06d}.png', alpha=alpha)          

                            stats_metrics.update(result_dict.results)

                data = dict(ema=G_ema)
                for key, value in data.items():
                    if isinstance(value, torch.nn.Module):
                        value = copy.deepcopy(value).eval().requires_grad_(False)
                        # misc.check_ddp_consistency(value)
                        data[key] = value.cpu()
                    del value # conserve memory

                if dist.get_rank() == 0:
                    save_data(data=data, fname=os.path.join(run_dir, f'network-snapshot-{alpha:03f}-{cur_nimg//1000:06d}.pkl'))

                del data # conserve memory

            if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
                dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt')
                save_pt(pt=dict(fake_score=fake_score, G=G, G_ema=G_ema, fake_score_optimizer_state=fake_score_optimizer.state_dict(), g_optimizer_state=g_optimizer.state_dict()), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))

            # Update logs.
            training_stats.default_collector.update()
            if dist.get_rank() == 0:
                if stats_jsonl is None:
                    append_line(jsonl_line=json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n', fname=os.path.join(run_dir, f'stats_{alpha:03f}.jsonl'))

            dist.update_progress(cur_nimg // 1000, total_kimg)

            # Update state.
            cur_tick += 1
            tick_start_nimg = cur_nimg
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time
            if done:
                break

        # Done.
        dist.print0()
        dist.print0('Exiting...')
  4. sid_sd_util.py: adapted to SDXL's text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2
def load_sd15(pretrained_model_name_or_path, pretrained_vae_model_name_or_path, device, weight_dtype, 
              revision=None, variant=None, lora_config=None, enable_xformers=False, gradient_checkpointing=False):
    # Load the tokenizers, text encoders, VAE, UNet, and noise scheduler
    print(f'pretrained_model_name_or_path: {pretrained_model_name_or_path}')
    print(f'revision: {revision}')

    text_encoder_class, tokenizer_class = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, "text_encoder")

    text_encoder_2_class, tokenizer_2_class = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, "text_encoder_2")

    tokenizer_1 = tokenizer_class.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder_1 = text_encoder_class.from_pretrained(pretrained_model_name_or_path,
                                                      subfolder="text_encoder")

    tokenizer_2 = tokenizer_2_class.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    text_encoder_2 = text_encoder_2_class.from_pretrained(pretrained_model_name_or_path,
                                                          subfolder="text_encoder_2")

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
    )

    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    if noise_scheduler.config.prediction_type == "v_prediction":
        noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
        noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)

    # Freeze untrained components
    vae.requires_grad_(False)
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)

    # Move unet and text_encoders to device and cast to weight_dtype
    unet.to(device, dtype=weight_dtype)
    text_encoder_1.to(device, dtype=weight_dtype)
    text_encoder_2.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                ValueError(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    return unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2

def sid_sd_sampler(unet, latents, contexts, init_timesteps,  noise_scheduler, 
                         text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, resolution, dtype=torch.float16,return_images=False, vae=None,guidance_scale=1,num_steps=1,train_sampler=True,num_steps_eval=1):
    #The single-step version (num_steps=num_steps_eval=1) has been fully tested; the multi-step version is a work in progress

    # Get the text embedding for conditioning
    prompt=contexts
    batch_size = len(prompt)
    text_input_1 = tokenizer_1(
        prompt, 
        padding="max_length", 
        max_length=tokenizer_1.model_max_length, 
        truncation=True, 
        return_tensors="pt"
    )
    text_input_2 = tokenizer_2(
        prompt, 
        padding="max_length", 
        max_length=tokenizer_2.model_max_length, 
        truncation=True, 
        return_tensors="pt"
    )
    add_time_ids = build_condition_input(resolution, latents.device)
    added_cond_kwargs = {"time_ids": add_time_ids.repeat(batch_size, 1)}
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [text_input_1.input_ids, text_input_2.input_ids], 
            [text_encoder_1, text_encoder_2]
        ):
            prompt_embeds = text_encoder(text_input_ids.to(latents.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        text_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(latents.device)
        added_cond_kwargs["text_embeds"] = pooled_prompt_embeds.to(latents.device)

    if train_sampler:
        D_x = torch.zeros_like(latents).to(latents.device)
        step_indices = [torch.tensor(0).to(latents.device)]  # Initial step
        for i in range(num_steps):
            noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
            init_timesteps_i = (init_timesteps*(1-i/num_steps)).to(torch.long)
            latents = noise_scheduler.add_noise(D_x, noise, init_timesteps_i)
            latent_model_input = noise_scheduler.scale_model_input(latents, init_timesteps_i) 
            noise_pred = unet(latent_model_input.to(dtype), init_timesteps_i, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs).sample
            D_x  = noise_scheduler.step(noise_pred, init_timesteps_i[0], latents,return_dict=True).pred_original_sample  
    else:
        D_x = torch.zeros_like(latents).to(latents.device)
        for i in range(num_steps_eval):
            noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
            init_timesteps_i = (init_timesteps*(1-i/num_steps_eval)).to(torch.long)
            latents = noise_scheduler.add_noise(D_x, noise, init_timesteps_i)
            latent_model_input = noise_scheduler.scale_model_input(latents, init_timesteps_i) 
            with torch.no_grad():
                noise_pred = unet(latent_model_input.to(dtype), init_timesteps_i, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs).sample
            D_x  = noise_scheduler.step(noise_pred, init_timesteps_i[0], latents,return_dict=True).pred_original_sample

    if return_images:
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
        if needs_upcasting:
            upcast_vae(vae=vae)
            D_x = D_x.to(next(iter(vae.post_quant_conv.parameters())).dtype)
        images = vae.decode(D_x / vae.config.scaling_factor, return_dict=False)[0]
        #images = vae.decode(D_x /0.18215).sample
        # cast back to fp16 if needed
        if needs_upcasting:
            vae.to(dtype=torch.float16)
        return images
    else:
        return D_x

def sid_sd_denoise(unet, images, noise, contexts,timesteps,  noise_scheduler, 
                         text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, resolution,dtype=torch.float16,predict_x0=True,guidance_scale=1):
    # Get the text embedding for conditioning

    prompt = contexts
    batch_size = len(prompt)
    add_time_ids = build_condition_input(resolution, images.device)

    text_input_1 = tokenizer_1(
        prompt,
        padding='max_length',
        max_length=tokenizer_1.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    text_input_2 = tokenizer_2(
        prompt, 
        padding="max_length", 
        max_length=tokenizer_2.model_max_length, 
        truncation=True, 
        return_tensors="pt"
    )
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [
                text_input_1.input_ids, 
                text_input_2.input_ids
            ], 
            [text_encoder_1, text_encoder_2]
        ):
            prompt_embeds = text_encoder(text_input_ids.to(images.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        text_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(images.device)

    added_cond_kwargs_1 = {
        "time_ids": add_time_ids.repeat(batch_size, 1).to(images.device),
        "text_embeds": pooled_prompt_embeds.to(images.device)
    }

    latents = noise_scheduler.add_noise(images, noise, timesteps)
    latent_model_input = noise_scheduler.scale_model_input(latents, timesteps)
    noise_pred_text = unet(latent_model_input.to(dtype), timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs_1).sample

    uncond_input_1 = tokenizer_1(
        [''] * batch_size,
        padding='max_length',
        max_length=tokenizer_1.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    uncond_input_2 = tokenizer_2(
        [''] * batch_size,
        padding='max_length',
        max_length=tokenizer_2.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [
                uncond_input_1.input_ids, 
                uncond_input_2.input_ids
            ], 
            [text_encoder_1, text_encoder_2]
        ):
            prompt_embeds = text_encoder(text_input_ids.to(images.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        uncond_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(images.device)

    added_cond_kwargs_2 = {
        "time_ids": add_time_ids.repeat(batch_size, 1).to(images.device),
        "text_embeds": pooled_prompt_embeds.to(images.device)
    }

    noise_pred_uncond = unet(latent_model_input.to(dtype), timesteps, encoder_hidden_states=uncond_embeddings, added_cond_kwargs=added_cond_kwargs_2).sample

    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    if predict_x0:
        D_x = [noise_scheduler.step(n, t, z,return_dict=True).pred_original_sample for n, t, z in zip(noise_pred, timesteps, latents)]
        D_x = torch.stack(D_x)
        return D_x
    else:
        return noise_pred
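
For completeness, the helpers import_model_class_from_model_name_or_path, build_condition_input, and upcast_vae referenced above are not shown in the snippets. They roughly follow the conventions of diffusers' SDXL pipelines and training scripts; a minimal sketch of what they are assumed to look like (not the exact code):

import torch
from transformers import PretrainedConfig, CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path, subfolder, revision=None):
    # Pick the text-encoder class (and its tokenizer class) from the saved config,
    # as done in diffusers' SDXL training scripts.
    config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, revision=revision)
    model_class = config.architectures[0]
    if model_class == "CLIPTextModel":
        return CLIPTextModel, CLIPTokenizer
    elif model_class == "CLIPTextModelWithProjection":
        return CLIPTextModelWithProjection, CLIPTokenizer
    raise ValueError(f"{model_class} is not supported.")

def build_condition_input(resolution, device, dtype=torch.float16):
    # SDXL micro-conditioning "time ids": original_size + crops_coords_top_left + target_size,
    # returned as a (1, 6) tensor that the callers repeat per batch element via .repeat(batch_size, 1).
    original_size = (resolution, resolution)
    crops_coords_top_left = (0, 0)
    target_size = (resolution, resolution)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], device=device, dtype=dtype)

def upcast_vae(vae):
    # Run the VAE in float32 during decoding to avoid the fp16 overflow of the SDXL VAE,
    # mirroring the upcast logic in diffusers' StableDiffusionXLPipeline.
    vae.to(dtype=torch.float32)
    return vae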