koking0 opened 3 days ago
Could you please provide more details, such as the hardware used for distillation, whether fp16 or fp32 was used, and how many fake images were iterated over to train the generator?
Of course.
$ wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
$ python3.9 collect_env.py
Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (GCC) 8.2.0
Clang version: 3.8.0 (tags/RELEASE_380/final)
CMake version: version 3.28.1
Libc version: glibc-2.27
Python version: 3.9.13 (main, May 23 2022, 22:02:02) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.10.0-1.0.0.28-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A800-SXM4-80GB
GPU 1: NVIDIA A800-SXM4-80GB
GPU 2: NVIDIA A800-SXM4-80GB
GPU 3: NVIDIA A800-SXM4-80GB
GPU 4: NVIDIA A800-SXM4-80GB
GPU 5: NVIDIA A800-SXM4-80GB
GPU 6: NVIDIA A800-SXM4-80GB
GPU 7: NVIDIA A800-SXM4-80GB
Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.4.1
/usr/lib/libcudnn_adv_infer.so.8.4.1
/usr/lib/libcudnn_adv_train.so.8.4.1
/usr/lib/libcudnn_cnn_infer.so.8.4.1
/usr/lib/libcudnn_cnn_train.so.8.4.1
/usr/lib/libcudnn_ops_infer.so.8.4.1
/usr/lib/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 4
NUMA node(s): 4
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Stepping: 7
CPU MHz: 3200.000
CPU max MHz: 3900.0000
CPU min MHz: 1000.0000
BogoMIPS: 5000.00
Virtualization: VT-x
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 28160K
NUMA node0 CPU(s): 0-19,80-99
NUMA node1 CPU(s): 20-39,100-119
NUMA node2 CPU(s): 40-59,120-139
NUMA node3 CPU(s): 60-79,140-159
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku avx512_vnni md_clear flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] open-clip-torch==2.24.0
[pip3] torch==2.3.0
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] Could not collect
I mainly distilled the following model: https://civitai.com/models/112902/dreamshaper-xl
The main code changes are as follows.
torchrun --standalone --nproc_per_node=8 sid_train.py \
--outdir 'image_experiment/sid-lsg-train-runs/' \
--train_mode 1 \
--cfg_train_fake 2 \
--cfg_eval_fake 2 \
--cfg_eval_real 2 \
--optimizer 'adam' \
--data_prompt_text '/root/paddlejob/workspace/SiD-LSG-main/datasets' \
--resolution 1024 \
--alpha 1 \
--init_timestep 625 \
--batch 512 \
--fp16 1 \
--batch-gpu 1 \
--sd_model "/root/paddlejob/workspace/models/DreamShaper_xl_v2_1" \
--tick 2 \
--snap 50 \
--dump 100 \
--lr 0.000001 \
--glr 0.000001 \
--duration 10 \
--enable_xformers 1 \
--gradient_checkpointing 1 \
--ema 0
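For reference, this command spreads the effective batch over 8 GPUs with heavy gradient accumulation. A quick check of the arithmetic, mirroring the batch logic in the training code further below:

world_size = 8          # --nproc_per_node
batch_size = 512        # --batch
batch_gpu = 1           # --batch-gpu

batch_gpu_total = batch_size // world_size              # 64 images per GPU per optimizer step
num_accumulation_rounds = batch_gpu_total // batch_gpu  # 64 accumulation rounds per step
assert batch_size == batch_gpu * num_accumulation_rounds * world_size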
if opts.optimizer == 'adam':
    c.fake_score_optimizer_kwargs = dnnlib.EasyDict(class_name='bitsandbytes.optim.Adam8bit', lr=opts.lr, betas=[0.0, 0.999], eps=1e-8 if not opts.fp16 else 1e-6)
    c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='bitsandbytes.optim.Adam8bit', lr=opts.glr, betas=[0.0, 0.999], eps=1e-8 if not opts.fp16 else 1e-6)
else:
    # This is another optimizer choice; it may provide better performance, but we have not carefully tested it yet
    assert opts.optimizer == 'adamw'
    c.fake_score_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.AdamW', lr=opts.lr, betas=[0.0, 0.999], eps=1e-8 if not opts.fp16 else 1e-6, weight_decay=0.01)
    c.g_optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.AdamW', lr=opts.glr, betas=[0.0, 0.999], eps=1e-8 if not opts.fp16 else 1e-6, weight_decay=0.01)
c.init_timestep = opts.init_timestep
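For what it's worth, these kwargs resolve to a bitsandbytes 8-bit Adam with the first moment disabled (beta1 = 0) and a loosened eps under fp16. A minimal sketch of the equivalent direct construction (the Linear module is just a stand-in for the UNet):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(8, 8).cuda()  # stand-in for the network being optimized
# betas=(0.0, 0.999): no first-moment smoothing; eps=1e-6 instead of 1e-8 for fp16 stability
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-6, betas=(0.0, 0.999), eps=1e-6)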
# Distill Stable Diffusion with SiD-LSG
if train_mode:
    # Use a barrier so that only rank 0 downloads the weights from the internet and saves them to cache
    if dist.get_rank() != 0:
        torch.distributed.barrier()
    if dtype == torch.float16:
        # If the fp16 checkpoint variant is not available, load the fp32 version and then convert it to fp16
        unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2 = load_sd15(
            pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_vae_model_name_or_path=None,
            device=device, weight_dtype=dtype, variant="fp16", enable_xformers=enable_xformers, lora_config=lora_config)
    else:
        unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2 = load_sd15(
            pretrained_model_name_or_path=pretrained_model_name_or_path, pretrained_vae_model_name_or_path=None,
            device=device, weight_dtype=dtype, enable_xformers=enable_xformers, lora_config=lora_config)
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    dist.print0('Loading network completed')
    dist.print0('Noise scheduler:')
    dist.print0(noise_scheduler)
# Initialize.
start_time = time.time()
np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
torch.manual_seed(np.random.randint(1 << 31))
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
# Select batch size per GPU, used for gradient accumulation
batch_gpu_total = batch_size // dist.get_world_size()
if batch_gpu is None or batch_gpu > batch_gpu_total:
    batch_gpu = batch_gpu_total
num_accumulation_rounds = batch_gpu_total // batch_gpu
assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
# Parameters for latent diffusion
latent_img_channels = 4
# For the SDXL VAE, len(block_out_channels) == 4, so vae_scale_factor == 8 and 1024-pixel images map to 128x128 latents
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
latent_resolution = resolution // vae_scale_factor
# Prepare the random noise used for example image generation during training
if dist.get_rank() == 0:
    original_seed = torch.initial_seed()
    # Set a temporary random seed
    temporary_seed = 2024
    torch.manual_seed(temporary_seed)
    grid_size, images, contexts = setup_snapshot_image_grid(training_set=dataset_obj)
    #contexts = [""] * len(contexts)
    grid_z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
    grid_z = grid_z.split(batch_gpu)
    grid_c = split_list(contexts, batch_gpu)
    # Revert back to the original random seed
    torch.manual_seed(original_seed)

dataset_prompt_text_obj = dnnlib.util.construct_class_by_name(**dataset_prompt_text_kwargs)  # subclass of training.dataset.Dataset
dataset_prompt_text_sampler = misc.InfiniteSampler(dataset=dataset_prompt_text_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
dataset_prompt_text_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_prompt_text_obj, sampler=dataset_prompt_text_sampler, batch_size=batch_gpu, **data_loader_kwargs))

dist.print0("Example text prompts used for distillation:")
for _i in range(16):
    dist.print0(_i)
    _, contexts = next(dataset_prompt_text_iterator)
    dist.print0(contexts)
# Initialize true score net, fake score net, and generator
true_score = unet
true_score.eval().requires_grad_(False).to(device)
fake_score = copy.deepcopy(true_score).train().requires_grad_(True).to(device)
G = copy.deepcopy(true_score).train().requires_grad_(True).to(device)

# Setup optimizers.
dist.print0('Setting up optimizer...')
fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score.parameters(), **fake_score_optimizer_kwargs)  # subclass of torch.optim.Optimizer
g_optimizer = dnnlib.util.construct_class_by_name(params=G.parameters(), **g_optimizer_kwargs)  # subclass of torch.optim.Optimizer

# Resume training from a previous snapshot.
if resume_training is not None:
    dist.print0('checkpoint path:', resume_training)
    data = torch.load(resume_training, map_location=torch.device('cpu'))
    misc.copy_params_and_buffers(src_module=data['fake_score'], dst_module=fake_score, require_all=True)
    misc.copy_params_and_buffers(src_module=data['G'], dst_module=G, require_all=True)
    if ema_halflife_kimg > 0:
        G_ema = copy.deepcopy(G).eval().requires_grad_(False)
        misc.copy_params_and_buffers(src_module=data['G_ema'], dst_module=G_ema, require_all=True)
        G_ema.eval().requires_grad_(False)
    else:
        G_ema = G
    fake_score_optimizer.load_state_dict(data['fake_score_optimizer_state'])
    g_optimizer.load_state_dict(data['g_optimizer_state'])
    del data  # conserve memory
    dist.print0('Loading checkpoint completed')
    torch.distributed.barrier()

    # Setup GPU parallel computing.
    dist.print0('Setting up GPU parallel computing')
    fake_score_ddp = torch.nn.parallel.DistributedDataParallel(fake_score, device_ids=[device], broadcast_buffers=False, find_unused_parameters=False)
    G_ddp = torch.nn.parallel.DistributedDataParallel(G, device_ids=[device], broadcast_buffers=False, find_unused_parameters=False)
else:
    # Setup GPU parallel computing.
    dist.print0('Setting up GPU parallel computing')
    fake_score_ddp = torch.nn.parallel.DistributedDataParallel(fake_score, device_ids=[device], broadcast_buffers=False, find_unused_parameters=False)
    G_ddp = torch.nn.parallel.DistributedDataParallel(G, device_ids=[device], broadcast_buffers=False, find_unused_parameters=False)
    if ema_halflife_kimg > 0:
        G_ema = copy.deepcopy(G).eval().requires_grad_(False)
    else:
        G_ema = G

fake_score_ddp.eval().requires_grad_(False)
G_ddp.eval().requires_grad_(False)
# Train.
dist.print0(f'Training for {total_kimg} kimg...')
dist.print0()
cur_nimg = resume_kimg * 1000
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
dist.update_progress(cur_nimg // 1000, total_kimg)
stats_jsonl = None
stats_metrics = dict()

if resume_training is None:
    if dist.get_rank() == 0:
        print('Exporting sample real images...')
        save_image_grid(img=images, fname=os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size)
        print('Text prompts for example images:')
        for c in grid_c:
            dist.print0(c)
        print('Exporting sample fake images at initialization...')
        images = [sid_sd_sampler(unet=G_ema, latents=z, contexts=c, init_timesteps=init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                 noise_scheduler=noise_scheduler,
                                 text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                 tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                 resolution=resolution, dtype=dtype, return_images=True, vae=vae, num_steps=num_steps, train_sampler=False, num_steps_eval=1) for z, c in zip(grid_z, grid_c)]
        images = torch.cat(images).cpu().numpy()
        save_image_grid(img=images, fname=os.path.join(run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)
        del images

torch.distributed.barrier()
dist.print0('Start Running')
while True:
    torch.cuda.empty_cache()
    gc.collect()
    G_ddp.eval().requires_grad_(False)

    #----------------------------------------------------------------------------------------------
    # Update Fake Score Network
    fake_score_ddp.train().requires_grad_(True)
    fake_score_optimizer.zero_grad(set_to_none=True)
    for round_idx in range(num_accumulation_rounds):
        _, contexts = next(dataset_prompt_text_iterator)
        if use_context_dropout_train_fake:
            # Randomly replace ~10% of the prompts with the empty string (classifier-free-guidance-style context dropout)
            bool_tensor = torch.rand(len(contexts)) < 0.1
            contexts = ["" if flag else context for flag, context in zip(bool_tensor.tolist(), contexts)]
        z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
        noise = torch.randn_like(z)
        # Initialize timesteps
        init_timesteps = init_timestep * torch.ones((len(contexts),), device=device, dtype=torch.long)
        # Generate fake images (stop generator gradient)
        with misc.ddp_sync(G_ddp, False):
            with torch.no_grad():
                images = sid_sd_sampler(unet=G_ddp, latents=z, contexts=contexts, init_timesteps=init_timesteps,
                                        noise_scheduler=noise_scheduler,
                                        text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                        tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                        resolution=resolution, dtype=dtype, return_images=False, vae=None, num_steps=num_steps)
        timesteps = torch.randint(tmin, tmax, (len(contexts),), device=device, dtype=torch.long)
        # Compute loss for fake score network
        with misc.ddp_sync(fake_score_ddp, (round_idx == num_accumulation_rounds - 1)):
            # Denoise the fake images (stop generator gradient) under the fake score network, using guidance scale kappa1 = cfg_train_fake
            noise_fake = sid_sd_denoise(unet=fake_score_ddp, images=images, noise=noise, contexts=contexts, timesteps=timesteps,
                                        noise_scheduler=noise_scheduler,
                                        text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                        tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                        resolution=resolution, dtype=dtype, predict_x0=False, guidance_scale=cfg_train_fake)
            nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)
            if noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(images, noise, timesteps)
                nan_mask = nan_mask | torch.isnan(target).flatten(start_dim=1).any(dim=1)
            # Drop any samples that contain NaNs
            if nan_mask.any():
                non_nan_mask = ~nan_mask
                noise_fake = noise_fake[non_nan_mask]
                noise = noise[non_nan_mask]
                timesteps = timesteps[non_nan_mask]
                if noise_scheduler.config.prediction_type == "v_prediction":
                    target = target[non_nan_mask]
            if noise_scheduler.config.prediction_type == "v_prediction":
                loss = (noise_fake - target) ** 2
                # Min-SNR-style weighting for v-prediction: rescale the per-sample loss by snr / (snr + 1)
                snr = compute_snr(noise_scheduler, timesteps)
                loss = loss * (snr / (snr + 1)).view(-1, 1, 1, 1)
            else:
                loss = (noise_fake - noise) ** 2
            loss = loss.sum().mul(loss_scaling / batch_gpu_total)
            del images
            if len(noise) > 0:
                loss.backward()
                loss_fake_score_print = loss.item()
                training_stats.report('fake_score_Loss/loss', loss_fake_score_print)

    fake_score_ddp.eval().requires_grad_(False)
    # Update fake score network, zeroing out any non-finite gradient entries first
    for param in fake_score.parameters():
        if param.grad is not None:
            torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
    fake_score_optimizer.step()
    #----------------------------------------------------------------------------------------------
    # Update One-Step Generator Network
    G_ddp.train().requires_grad_(True)
    g_optimizer.zero_grad(set_to_none=True)
    for round_idx in range(num_accumulation_rounds):
        _, contexts = next(dataset_prompt_text_iterator)
        if use_context_dropout_train_G:
            # Randomly replace ~10% of the prompts with the empty string
            bool_tensor = torch.rand(len(contexts)) < 0.1
            contexts = ["" if flag else context for flag, context in zip(bool_tensor.tolist(), contexts)]
        z = torch.randn([len(contexts), latent_img_channels, latent_resolution, latent_resolution], device=device, dtype=dtype)
        noise = torch.randn_like(z)
        # Initialize timesteps
        init_timesteps = init_timestep * torch.ones((len(contexts),), device=device, dtype=torch.long)
        timesteps = torch.randint(tmin, tmax, (len(contexts),), device=device, dtype=torch.long)
        # Generate fake images (track generator gradient)
        with misc.ddp_sync(G_ddp, (round_idx == num_accumulation_rounds - 1)):
            images = sid_sd_sampler(unet=G_ddp, latents=z, contexts=contexts, init_timesteps=init_timesteps,
                                    noise_scheduler=noise_scheduler,
                                    text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                    tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                    resolution=resolution, dtype=dtype, return_images=False, num_steps=num_steps)
        # Compute loss for generator
        with misc.ddp_sync(fake_score_ddp, False):
            # Denoised fake images (track generator gradient) under the fake score network, using guidance scale kappa2 = kappa3 = cfg_eval_fake
            y_fake = sid_sd_denoise(unet=fake_score_ddp, images=images, noise=noise, contexts=contexts, timesteps=timesteps,
                                    noise_scheduler=noise_scheduler,
                                    text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                    tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                    resolution=resolution, dtype=dtype, guidance_scale=cfg_eval_fake)
            # Denoised fake images (track generator gradient) under the pretrained score network, using guidance scale kappa4 = cfg_eval_real
            y_real = sid_sd_denoise(unet=true_score, images=images, noise=noise, contexts=contexts, timesteps=timesteps,
                                    noise_scheduler=noise_scheduler,
                                    text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                    tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                    resolution=resolution, dtype=dtype, guidance_scale=cfg_eval_real)
            nan_mask_images = torch.isnan(images).flatten(start_dim=1).any(dim=1)
            nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
            nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
            nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake
            # Drop any samples that contain NaNs
            if nan_mask.any():
                non_nan_mask = ~nan_mask
                images = images[non_nan_mask]
                y_real = y_real[non_nan_mask]
                y_fake = y_fake[non_nan_mask]
            with torch.no_grad():
                weight_factor = abs(images.to(dtype) - y_real.to(dtype)).mean(dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
            if alpha == 1:
                loss = (y_real - y_fake) * (y_fake - images) / weight_factor
            else:
                loss = (y_real - y_fake) * ((y_real - images) - alpha * (y_real - y_fake)) / weight_factor
            loss = loss.sum().mul(loss_scaling_G / batch_gpu_total)
            if len(y_real) > 0:
                loss.backward()
                lossG_print = loss.item()
                training_stats.report('G_Loss/loss', lossG_print)

    G_ddp.eval().requires_grad_(False)
    # Update generator, zeroing out any non-finite gradient entries first
    for param in G.parameters():
        if param.grad is not None:
            torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
    # Apply gradient clipping under fp16 to prevent sudden divergence
    if dtype == torch.float16 and len(y_real) > 0:
        torch.nn.utils.clip_grad_value_(G.parameters(), 1)
    g_optimizer.step()
    if ema_halflife_kimg > 0:
        # Update EMA.
        ema_halflife_nimg = ema_halflife_kimg * 1000
        if ema_rampup_ratio is not None:
            ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
        ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
        for p_ema, p_true_score in zip(G_ema.parameters(), G.parameters()):
            with torch.no_grad():
                p_ema.copy_(p_true_score.detach().lerp(p_ema, ema_beta))
    else:
        G_ema = G

    cur_nimg += batch_size
    done = (cur_nimg >= total_kimg * 1000)
    if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
        continue

    # Print status line, accumulating the same information in training_stats.
    tick_end_time = time.time()
    fields = []
    fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
    fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
    fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
    fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
    fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
    fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
    fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
    fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
    fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
    fields += [f"loss_fake_score {training_stats.report0('fake_score_Loss/loss', loss_fake_score_print):<6.2f}"]
    fields += [f"loss_G {training_stats.report0('G_Loss/loss', lossG_print):<6.2f}"]
    torch.cuda.reset_peak_memory_stats()
    dist.print0(' '.join(fields))

    # Check for abort.
    if (not done) and dist.should_stop():
        done = True
        dist.print0()
        dist.print0('Aborting...')
    if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0 or cur_tick in [2, 4, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]):
        dist.print0('Exporting sample images...')
        if dist.get_rank() == 0:
            for num_steps_eval in [1, 2, 4]:
                # While the generator is primarily trained to generate images in a single step, it can also be used in a multi-step setting during evaluation.
                # To do: distill a multi-step generator that is optimized for multi-step settings
                images = [sid_sd_sampler(unet=G_ema, latents=z, contexts=c, init_timesteps=init_timestep * torch.ones((len(c),), device=device, dtype=torch.long),
                                         noise_scheduler=noise_scheduler,
                                         text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                         tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                         resolution=resolution, dtype=dtype, return_images=True, vae=vae, num_steps=num_steps, train_sampler=False, num_steps_eval=num_steps_eval) for z, c in zip(grid_z, grid_c)]
                images = torch.cat(images).cpu().numpy()
                save_image_grid(img=images, fname=os.path.join(run_dir, f'fakes_{alpha:03f}_{cur_nimg//1000:06d}_{num_steps_eval:d}.png'), drange=[-1, 1], grid_size=grid_size)
                del images
        if cur_tick > 0:
            dist.print0('Evaluating metrics...')
            dist.print0(metric_pt_path)
            if metrics is not None:
                for metric in metrics:
                    result_dict = metric_main.calc_metric(metric=metric, metric_pt_path=metric_pt_path, metric_open_clip_path=metric_open_clip_path, metric_clip_path=metric_clip_path,
                                                          G=partial(sid_sd_sampler, unet=G_ema, noise_scheduler=noise_scheduler,
                                                                    text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2,
                                                                    tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2,
                                                                    resolution=resolution, dtype=dtype, return_images=True, vae=vae, num_steps=num_steps, train_sampler=False, num_steps_eval=1),
                                                          init_timestep=init_timestep,
                                                          dataset_kwargs=dataset_kwargs, num_gpus=dist.get_world_size(), rank=dist.get_rank(), local_rank=dist.get_local_rank(), device=device)
                    if dist.get_rank() == 0:
                        print(result_dict.results)
                        metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=f'fakes_{alpha:03f}_{cur_nimg//1000:06d}.png', alpha=alpha)
                    stats_metrics.update(result_dict.results)
        data = dict(ema=G_ema)
        for key, value in data.items():
            if isinstance(value, torch.nn.Module):
                value = copy.deepcopy(value).eval().requires_grad_(False)
                # misc.check_ddp_consistency(value)
                data[key] = value.cpu()
            del value  # conserve memory
        if dist.get_rank() == 0:
            save_data(data=data, fname=os.path.join(run_dir, f'network-snapshot-{alpha:03f}-{cur_nimg//1000:06d}.pkl'))
        del data  # conserve memory

    if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
        dist.print0(f'saving checkpoint: training-state-{cur_nimg//1000:06d}.pt')
        save_pt(pt=dict(fake_score=fake_score, G=G, G_ema=G_ema, fake_score_optimizer_state=fake_score_optimizer.state_dict(), g_optimizer_state=g_optimizer.state_dict()), fname=os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))

    # Update logs.
    training_stats.default_collector.update()
    if dist.get_rank() == 0:
        if stats_jsonl is None:
            append_line(jsonl_line=json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n', fname=os.path.join(run_dir, f'stats_{alpha:03f}.jsonl'))
    dist.update_progress(cur_nimg // 1000, total_kimg)

    # Update state.
    cur_tick += 1
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - tick_end_time
    if done:
        break

# Done.
dist.print0()
dist.print0('Exiting...')
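As an aside, the EMA decay in the loop above follows an EDM-style half-life schedule. A worked example with an illustrative half-life (note the posted command uses --ema 0, so there G_ema = G and this branch is inactive):

batch_size = 512
ema_halflife_kimg = 500                      # illustrative value, not from the posted run
ema_halflife_nimg = ema_halflife_kimg * 1000
ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
print(f"{ema_beta:.6f}")                     # ~0.999291: each step keeps ~99.93% of the EMA weights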
def load_sd15(pretrained_model_name_or_path, pretrained_vae_model_name_or_path, device, weight_dtype,
              revision=None, variant=None, lora_config=None, enable_xformers=False, gradient_checkpointing=False):
    # Despite the legacy name, this loads the SDXL components (two tokenizers and two text encoders)
    print(f'pretrained_model_name_or_path: {pretrained_model_name_or_path}')
    print(f'revision: {revision}')
    text_encoder_class, tokenizer_class = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, "text_encoder")
    text_encoder_2_class, tokenizer_2_class = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, "text_encoder_2")
    tokenizer_1 = tokenizer_class.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", revision=revision)
    text_encoder_1 = text_encoder_class.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, variant=variant)
    tokenizer_2 = tokenizer_2_class.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2", revision=revision)
    text_encoder_2 = text_encoder_2_class.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision, variant=variant)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision, variant=variant)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", revision=revision, variant=variant)
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    if noise_scheduler.config.prediction_type == "v_prediction":
        noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
        noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)

    # Freeze untrained components
    vae.requires_grad_(False)
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)

    # Move unet, text encoders, and vae to the device and cast to weight_dtype
    unet.to(device, dtype=weight_dtype)
    text_encoder_1.to(device, dtype=weight_dtype)
    text_encoder_2.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "Warning: xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems "
                    "during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    return unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2
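A hypothetical invocation of the loader above for an fp16 run (the path is a placeholder; if the checkpoint ships no fp16 variant, load fp32 and convert, as noted in the training code):

import torch

unet, vae, noise_scheduler, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2 = load_sd15(
    pretrained_model_name_or_path="/path/to/DreamShaper_xl_v2_1",  # placeholder path
    pretrained_vae_model_name_or_path=None,
    device="cuda",
    weight_dtype=torch.float16,
    variant="fp16",
    enable_xformers=True,
    gradient_checkpointing=True,
)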
def sid_sd_sampler(unet, latents, contexts, init_timesteps, noise_scheduler,
                   text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, resolution, dtype=torch.float16,
                   return_images=False, vae=None, guidance_scale=1, num_steps=1, train_sampler=True, num_steps_eval=1):
    # The single-step version (num_steps=num_steps_eval=1) has been fully tested; the multi-step version is work in progress
    # Get the text embeddings for conditioning
    prompt = contexts
    batch_size = len(prompt)
    text_input_1 = tokenizer_1(
        prompt,
        padding="max_length",
        max_length=tokenizer_1.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_2 = tokenizer_2(
        prompt,
        padding="max_length",
        max_length=tokenizer_2.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    add_time_ids = build_condition_input(resolution, latents.device)
    added_cond_kwargs = {"time_ids": add_time_ids.repeat(batch_size, 1)}
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [text_input_1.input_ids, text_input_2.input_ids],
            [text_encoder_1, text_encoder_2],
        ):
            prompt_embeds = text_encoder(text_input_ids.to(latents.device), output_hidden_states=True)
            # SDXL conditions on the pooled output of the second text encoder and the penultimate hidden states of both
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)
    text_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(latents.device)
    added_cond_kwargs["text_embeds"] = pooled_prompt_embeds.to(latents.device)

    if train_sampler:
        D_x = torch.zeros_like(latents).to(latents.device)
        for i in range(num_steps):
            noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
            init_timesteps_i = (init_timesteps * (1 - i / num_steps)).to(torch.long)
            latents = noise_scheduler.add_noise(D_x, noise, init_timesteps_i)
            latent_model_input = noise_scheduler.scale_model_input(latents, init_timesteps_i)
            noise_pred = unet(latent_model_input.to(dtype), init_timesteps_i, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs).sample
            D_x = noise_scheduler.step(noise_pred, init_timesteps_i[0], latents, return_dict=True).pred_original_sample
    else:
        D_x = torch.zeros_like(latents).to(latents.device)
        for i in range(num_steps_eval):
            noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
            init_timesteps_i = (init_timesteps * (1 - i / num_steps_eval)).to(torch.long)
            latents = noise_scheduler.add_noise(D_x, noise, init_timesteps_i)
            latent_model_input = noise_scheduler.scale_model_input(latents, init_timesteps_i)
            with torch.no_grad():
                noise_pred = unet(latent_model_input.to(dtype), init_timesteps_i, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs).sample
            D_x = noise_scheduler.step(noise_pred, init_timesteps_i[0], latents, return_dict=True).pred_original_sample

    if return_images:
        # Make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
        if needs_upcasting:
            upcast_vae(vae=vae)
            D_x = D_x.to(next(iter(vae.post_quant_conv.parameters())).dtype)
        images = vae.decode(D_x / vae.config.scaling_factor, return_dict=False)[0]
        # Cast back to fp16 if needed
        if needs_upcasting:
            vae.to(dtype=torch.float16)
        return images
    else:
        return D_x
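build_condition_input is not included in the paste; a plausible sketch of what it returns for SDXL's micro-conditioning, assuming square images with no cropping (the real helper may differ):

import torch

def build_condition_input(resolution, device):
    # SDXL micro-conditioning: (original_size, crop_top_left, target_size),
    # flattened into the 6-dim time_ids that added_cond_kwargs expects;
    # callers .repeat(batch_size, 1) this [1, 6] tensor.
    add_time_ids = torch.tensor(
        [[resolution, resolution, 0, 0, resolution, resolution]],
        dtype=torch.float32,  # may need casting to the UNet dtype under fp16
        device=device,
    )
    return add_time_ids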
def sid_sd_denoise(unet, images, noise, contexts, timesteps, noise_scheduler,
                   text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, resolution, dtype=torch.float16,
                   predict_x0=True, guidance_scale=1):
    # Get the text embeddings for conditioning
    prompt = contexts
    batch_size = len(prompt)
    add_time_ids = build_condition_input(resolution, images.device)
    text_input_1 = tokenizer_1(
        prompt,
        padding='max_length',
        max_length=tokenizer_1.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    text_input_2 = tokenizer_2(
        prompt,
        padding='max_length',
        max_length=tokenizer_2.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [text_input_1.input_ids, text_input_2.input_ids],
            [text_encoder_1, text_encoder_2],
        ):
            prompt_embeds = text_encoder(text_input_ids.to(images.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)
    text_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(images.device)
    added_cond_kwargs_1 = {
        "time_ids": add_time_ids.repeat(batch_size, 1).to(images.device),
        "text_embeds": pooled_prompt_embeds.to(images.device),
    }

    latents = noise_scheduler.add_noise(images, noise, timesteps)
    latent_model_input = noise_scheduler.scale_model_input(latents, timesteps)
    noise_pred_text = unet(latent_model_input.to(dtype), timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs=added_cond_kwargs_1).sample

    # Unconditional (empty-prompt) branch for classifier-free guidance
    uncond_input_1 = tokenizer_1(
        [''] * batch_size,
        padding='max_length',
        max_length=tokenizer_1.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    uncond_input_2 = tokenizer_2(
        [''] * batch_size,
        padding='max_length',
        max_length=tokenizer_2.model_max_length,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        prompt_embeds_list = []
        for text_input_ids, text_encoder in zip(
            [uncond_input_1.input_ids, uncond_input_2.input_ids],
            [text_encoder_1, text_encoder_2],
        ):
            prompt_embeds = text_encoder(text_input_ids.to(images.device), output_hidden_states=True)
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)
    uncond_embeddings = torch.concat(prompt_embeds_list, dim=-1).to(images.device)
    added_cond_kwargs_2 = {
        "time_ids": add_time_ids.repeat(batch_size, 1).to(images.device),
        "text_embeds": pooled_prompt_embeds.to(images.device),
    }
    noise_pred_uncond = unet(latent_model_input.to(dtype), timesteps, encoder_hidden_states=uncond_embeddings, added_cond_kwargs=added_cond_kwargs_2).sample

    # Classifier-free guidance
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    if predict_x0:
        D_x = [noise_scheduler.step(n, t, z, return_dict=True).pred_original_sample for n, t, z in zip(noise_pred, timesteps, latents)]
        D_x = torch.stack(D_x)
        return D_x
    else:
        return noise_pred
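Incidentally, for epsilon-prediction the per-sample noise_scheduler.step list comprehension in the predict_x0 branch can be replaced by the closed-form estimate x0 = (z_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t). A vectorized sketch, assuming a DDPM-style scheduler exposing alphas_cumprod:

import torch

def predict_x0_from_eps(noise_scheduler, latents, noise_pred, timesteps):
    # Closed-form x0 for epsilon-prediction: x0 = (z_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
    abar = noise_scheduler.alphas_cumprod.to(latents.device)[timesteps]
    abar = abar.view(-1, *([1] * (latents.dim() - 1)))  # broadcast over C, H, W
    return (latents - (1 - abar).sqrt() * noise_pred) / abar.sqrt()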
Has anyone encountered the following problem? I used SiD-LSG to distill an SDXL model (with some code adaptations for the two text encoders), and color spots appeared on faces that are very obvious when zoomed in.