Meshwa428 opened this issue 7 months ago
@Meshwa428 Hi, have you solved it yet? I'm facing the same problem. I tried to move kwargs = dict(...)
inside the iteration below, but it throws a new error:
RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=
Hello @Superviro, you can use this code. I did the same thing you did, but with some minor changes, like adding an if block.
Also, could you tell me which platform/GPU provider you are using to train DMD? I'm trying to find a cheap but good option to train my own DMD with data generated using SDXL (i.e., creating a distilled version of SDXL using the DMD method).
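About the assert you hit: the truncated device/num_gpus values aren't visible in your message, but that error typically means a cuda:N index that is out of range for the devices visible to the process. The script below avoids it by setting CUDA_VISIBLE_DEVICES to one physical GPU inside each spawned worker and then moving the pipeline to plain "cuda" rather than an indexed cuda:N. A minimal standalone sketch of just that pattern (names here are placeholders, not part of the script below):

import multiprocessing as mp
import os
import torch

def worker(device_id):
    # With the "spawn" start method, CUDA is not yet initialized in the child,
    # so narrowing visibility here still takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    # Only one device is visible now, so address it as plain "cuda";
    # an indexed f"cuda:{device_id}" would be out of range for device_id > 0.
    x = torch.zeros(1, device="cuda")
    print(f"worker {device_id} is running on {x.device}")

if __name__ == "__main__":
    mp.set_start_method("spawn")
    procs = [mp.Process(target=worker, args=(d,)) for d in range(torch.cuda.device_count())]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Full script: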
import argparse
import multiprocessing as mp
import os
import time
import signal
import json
import random
import torch
from diffusers import AutoPipelineForText2Image, DEISMultistepScheduler, AutoPipelineForImage2Image
def get_input_ids(pipeline, prompt):
    # Tokenize without truncation so prompts longer than the tokenizer's
    # model_max_length (77 tokens for CLIP) are kept in full.
    input_ids = pipeline.tokenizer(
        prompt, return_tensors="pt", truncation=False
    ).input_ids.to("cuda")
    shape_max_length = input_ids.shape[-1]
    return input_ids, shape_max_length
def embed_prompts(pipeline, prompt):
    input_ids, shape_max_length = get_input_ids(
        pipeline=pipeline, prompt=prompt
    )
    max_length = pipeline.tokenizer.model_max_length
    # Encode the prompt in max_length-sized chunks and concatenate the
    # embeddings, so long prompts are not cut off at 77 tokens.
    concat_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipeline.text_encoder(
            input_ids[:, i: i + max_length])[0])
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    return prompt_embeds
def run(device_id, job_id, worker_id, n_gpu, n_worker, caption_path, model_id, save_dir, size=None):
    global_id = device_id * n_worker + worker_id
    n_job = n_gpu * n_worker
    # Workers ignore SIGINT; the parent handles Ctrl-C and terminates them.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    # Pin this worker to a single physical GPU; inside this process it is plain "cuda".
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    print(f"[{job_id}] Using device: {device_id} global_id {global_id}")
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id, torch_dtype=torch.float16, local_files_only=False)
    pipe.safety_checker = None
    pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    pipe.set_progress_bar_config(disable=True)
    sd_img2img = AutoPipelineForImage2Image.from_pipe(pipe)
    sd_img2img.set_progress_bar_config(disable=True)

    # Load the captions and take this worker's contiguous chunk.
    with open(caption_path, "r") as f:
        prompts = f.readlines()
    if size is not None:
        prompts = random.choices(prompts, k=size)
    chunk_size = len(prompts) // n_job
    prompts = prompts[global_id * chunk_size: (global_id + 1) * chunk_size]
    print("[{}] process chunk size {}, range [{}:{}]".format(
        job_id, chunk_size, global_id * chunk_size, (global_id + 1) * chunk_size))

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"meta_{job_id}.json"), "a") as f:
        for i, prompt in enumerate(prompts):
            prompt = prompt.strip()
            batch_size = 1
            height = 512
            width = 512
            num_inference_steps = 3
            guidance_scale = 1.155
            try:
                num_channels_latents = pipe.unet.config.in_channels
            except Exception:
                # Transformer-based pipelines have no unet attribute.
                num_channels_latents = pipe.transformer.config.in_channels
            # Seed the initial latents with the prompt index so they can be reproduced.
            latents = pipe.prepare_latents(
                batch_size,
                num_channels_latents,
                height,
                width,
                dtype=torch.float16,
                device=pipe.device,
                generator=torch.Generator().manual_seed(i),
            )
            prompt_embeds = embed_prompts(
                pipeline=pipe,
                prompt=prompt,
            )
            try:
                image = pipe(
                    prompt_embeds=prompt_embeds,
                    latents=latents,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale).images[0]
                image_path = os.path.join("images", f"{job_id}", f"{i}.jpg")
                latent_path = os.path.join("latents", f"{job_id}", f"{i}.pt")
                print("saved:", image_path)
                os.makedirs(os.path.dirname(os.path.join(
                    save_dir, image_path)), exist_ok=True)
                os.makedirs(os.path.dirname(os.path.join(
                    save_dir, latent_path)), exist_ok=True)
                # Upscale the 512x512 output and refine it with img2img.
                new_image = image.resize((1024, 1024))
                final_image = sd_img2img(prompt_embeds=prompt_embeds,
                                         latents=latents,
                                         num_inference_steps=num_inference_steps,
                                         guidance_scale=guidance_scale,
                                         image=new_image
                                         ).images[0]
                final_image.save(os.path.join(save_dir, image_path))
                torch.save(latents, os.path.join(save_dir, latent_path))
                f.write(json.dumps(
                    {"image_path": image_path, "latent_path": latent_path, "prompt": prompt, "seed": i}) + "\n")
                f.flush()
            except Exception:
                # Skip prompts that fail (e.g. OOM) and keep going.
                pass
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", default=None, type=str,
                        help="Comma separated list of GPUs to use")
    parser.add_argument("--workers", default=1, type=int,
                        help="Number of workers spawned per GPU (default 1)")
    parser.add_argument("--size", default=None, type=int)
    parser.add_argument(
        "--caption_path", default="diffusion_db_prompts.txt", type=str)
    parser.add_argument(
        "--model_id", default="runwayml/stable-diffusion-v1-5", type=str)
    parser.add_argument(
        "--save_dir", default="data/diffusion_db_runwayml_stable-diffusion-v1-5", type=str)
    args = parser.parse_args()
    return args
def main():
    mp.set_start_method("spawn")
    args = parse_args()
    num_gpus = torch.cuda.device_count()
    num_workers = args.workers
    if args.gpus is None:
        visible_gpus = list(range(num_gpus))
    else:
        # --gpus accepts comma-separated ids and ranges, e.g. "0,2,4-7".
        visible_gpus = []
        parts = args.gpus.split(",")
        for p in parts:
            if "-" in p:
                lo, hi = p.split("-")
                lo, hi = int(lo), int(hi)
                assert hi >= lo
                visible_gpus.extend(list(range(lo, hi + 1)))
            else:
                visible_gpus.append(int(p))
        visible_gpus = list(set(visible_gpus))  # keep distinct
    assert len(visible_gpus) > 0
    print("Using GPUs: ", visible_gpus)

    jobs = {}
    for device_id in visible_gpus:
        for i in range(num_workers):
            job_id = f"GPU{device_id:02d}-{i}"
            print(f"[{job_id}] Launching worker-process...")
            # The if block mentioned above: kwargs is rebuilt only for the
            # first worker on each GPU.
            if i <= 0:
                kwargs = dict(
                    device_id=device_id,
                    job_id=job_id,
                    worker_id=i,
                    n_gpu=len(visible_gpus),
                    n_worker=num_workers,
                    caption_path=args.caption_path,
                    model_id=args.model_id,
                    save_dir=args.save_dir,
                    size=args.size,
                )
            p = mp.Process(target=run, kwargs=kwargs)
            jobs[job_id] = (p, device_id)
            p.start()

    try:
        while True:
            time.sleep(1)
            for job_id, (job, device_id) in jobs.items():
                if job.is_alive():
                    pass
                else:
                    print(f"[{job_id}] Worker died, cleaning up...")
                    # Remove this job's partial output before respawning.
                    os.system(f"rm -r -f -v {args.save_dir}/images/{job_id}")
                    os.system(f"rm -r -f -v {args.save_dir}/latents/{job_id}")
                    os.system(f"rm -f -v {args.save_dir}/meta_{job_id}.json")
                    print(f"[{job_id}] respawning...")
                    p = mp.Process(target=run, kwargs=kwargs)
                    jobs[job_id] = (p, device_id)
                    p.start()
    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        for job_id, (job, device_id) in jobs.items():
            job.terminate()
        for job_id, (job, device_id) in jobs.items():
            print(f"[{job_id}] waiting for exit...")
            job.join()
        print("done.")


if __name__ == "__main__":
    main()
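For reference, you would launch it with something like: python generate_data.py --gpus 0-3 --workers 1 --caption_path diffusion_db_prompts.txt --save_dir data/diffusion_db_runwayml_stable-diffusion-v1-5 (the filename generate_data.py is just a placeholder for whatever you save the script as; the flags are the ones defined in parse_args above).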
@Meshwa428 Thanks! And I'm using GPUs from my school lab
Damnnnn, your school must be rich 🤑. Thanks for the info though.
This is the error I am getting after running this command:
Error: