Zeqiang-Lai / OpenDMD

Open-source implementation and models of One-step Diffusion with Distribution Matching Distillation
GNU General Public License v2.0
94 stars 11 forks source link

build_regression_data.py throws an error #5

Open Meshwa428 opened 2 months ago

Meshwa428 commented 2 months ago

This is the error I get after running this command:

!python build_regression_data.py --model_id stabilityai/sdxl-turbo

Error:

2024-04-24 05:41:02.837908: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-24 05:41:02.837966: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-24 05:41:02.839433: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-24 05:41:03.930974: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Using GPUs:  [0]
Traceback (most recent call last):
  File "/content/OpenDMD/build_regression_data.py", line 163, in <module>
    main()
  File "/content/OpenDMD/build_regression_data.py", line 114, in main
    device_id=device_id,
UnboundLocalError: local variable 'device_id' referenced before assignment
Superviro commented 2 months ago

@Meshwa428 Hi, have you solved it yet? I'm facing the same problem. I tried moving `kwargs = dict(...)` inside the loop below, but it throws a new error: RuntimeError: device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp":50, please report a bug to PyTorch. device=, num_gpus=

Meshwa428 commented 2 months ago

Hello @Superviro, you can use this code. I did the same thing you did, with some minor changes such as adding an `if` block.

Can you please tell me which platform/GPU provider you are using to train DMD? I am trying to find a cheap and good solution for training my own DMD on data generated with SDXL (i.e., creating a distilled version of SDXL using the DMD method).

import argparse
import contextlib
import json
import multiprocessing as mp
import os
import random
import shutil
import signal
import time

import torch
from diffusers import AutoPipelineForText2Image, DEISMultistepScheduler, AutoPipelineForImage2Image

def get_input_ids(pipeline, prompt):
    """Tokenize *prompt* without truncation and place the ids next to the text encoder.

    Args:
        pipeline: a diffusers pipeline exposing ``tokenizer`` and ``text_encoder``.
        prompt: the text prompt to tokenize.

    Returns:
        tuple: ``(input_ids, n_tokens)`` where ``input_ids`` is the ``"pt"``
        tensor produced by ``pipeline.tokenizer`` moved to
        ``pipeline.text_encoder.device``, and ``n_tokens`` is its size along the
        last dimension (the full, untruncated token count).
    """
    # Use the encoder's own device instead of a hard-coded "cuda" so the ids
    # always land where the module that consumes them lives.
    input_ids = pipeline.tokenizer(
        prompt, return_tensors="pt", truncation=False
    ).input_ids.to(pipeline.text_encoder.device)
    return input_ids, input_ids.shape[-1]

def embed_prompts(pipeline, prompt):
    """Encode *prompt* with ``pipeline.text_encoder`` without the tokenizer's length cap.

    The prompt is tokenized without truncation (via ``get_input_ids``), then
    fed to the text encoder in consecutive windows of
    ``pipeline.tokenizer.model_max_length`` tokens. The first output of each
    window is concatenated along dim 1.

    Args:
        pipeline: a diffusers pipeline exposing ``tokenizer`` and ``text_encoder``.
        prompt: the text prompt to encode.

    Returns:
        torch.Tensor: the concatenated encoder outputs. Assumed layout is
        ``(batch, tokens, hidden)`` — TODO confirm for non-CLIP encoders.
    """
    input_ids, _ = get_input_ids(pipeline=pipeline, prompt=prompt)
    max_length = pipeline.tokenizer.model_max_length

    # NOTE(review): with a CLIP-style tokenizer only the first window starts
    # with BOS and only the last ends with EOS; middle windows are encoded
    # without special tokens.
    # Inference only: if the encoder's parameters require grad, running it
    # outside no_grad would record an autograd graph that lives as long as
    # the returned embeddings.
    with torch.no_grad():
        window_outputs = [
            pipeline.text_encoder(input_ids[:, start:start + max_length])[0]
            for start in range(0, input_ids.shape[-1], max_length)
        ]
    return torch.cat(window_outputs, dim=1)

def _chunk_bounds(n_items, index, n_chunks):
    """Return the half-open ``(start, end)`` range owned by chunk *index* of *n_chunks*.

    Every chunk holds ``n_items // n_chunks`` items except the last one, which
    also takes the remainder so that no item is silently dropped.
    """
    chunk_size = n_items // n_chunks
    start = index * chunk_size
    end = n_items if index == n_chunks - 1 else start + chunk_size
    return start, end

def _prompt_kwargs(pipe, prompt):
    """Return the prompt-conditioning keyword arguments for one call of *pipe*."""
    if getattr(pipe, "text_encoder_2", None) is not None:
        # Dual-encoder pipelines (e.g. SDXL / sdxl-turbo) refuse prompt_embeds
        # without matching pooled_prompt_embeds, and embed_prompts() only runs
        # pipe.text_encoder, so let the pipeline encode the raw prompt itself.
        return {"prompt": prompt}
    return {"prompt_embeds": embed_prompts(pipeline=pipe, prompt=prompt)}

def run(device_id, job_id, worker_id, n_gpu, n_worker, caption_path, model_id, save_dir, size=None):
    """Worker entry point: render images/latents for one slice of the captions.

    The captions in *caption_path* (one per line) are split across all
    ``n_gpu * n_worker`` jobs; this job handles slice number
    ``device_id * n_worker + worker_id``.

    For every prompt it:
      1. draws latents seeded with the prompt's index inside the slice;
      2. runs the text-to-image pipeline at 512x512;
      3. resizes the result to 1024x1024 and passes it through the
         image-to-image pipeline.

    Outputs are written under *save_dir*: ``images/<job_id>/<i>.jpg``,
    ``latents/<job_id>/<i>.pt`` and one JSON line per success in
    ``meta_<job_id>.json``. Prompts whose generation fails are reported and
    skipped.

    Args:
        device_id: GPU index; exported as CUDA_VISIBLE_DEVICES for this process.
        job_id: label used in log lines and output paths.
        worker_id: index of this worker on its GPU.
        n_gpu: number of GPUs taking part.
        n_worker: number of workers per GPU.
        caption_path: text file with one prompt per line.
        model_id: model passed to ``AutoPipelineForText2Image.from_pretrained``.
        save_dir: output directory.
        size: if given, first sample this many prompts (with replacement).
    """
    global_id = device_id * n_worker + worker_id
    n_job = n_gpu * n_worker
    # Ctrl-C is handled by the parent process, which terminates the workers.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # Set before CUDA is first used in this spawned process so that "cuda"
    # below refers to the assigned GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    print(f"[{job_id}] Using device: {device_id} global_id {global_id}")

    pipe = AutoPipelineForText2Image.from_pretrained(
        model_id, torch_dtype=torch.float16, local_files_only=False)
    pipe.safety_checker = None
    pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    pipe.set_progress_bar_config(disable=True)
    sd_img2img = AutoPipelineForImage2Image.from_pipe(pipe)
    sd_img2img.set_progress_bar_config(disable=True)

    # Load the captions.
    with open(caption_path, "r", encoding="utf-8") as f:
        prompts = f.readlines()
    if size is not None:
        # Fixed seed: every worker process must draw the *same* sample,
        # otherwise the per-job slices come from different lists and overlap.
        prompts = random.Random(0).choices(prompts, k=size)

    start, end = _chunk_bounds(len(prompts), global_id, n_job)
    prompts = prompts[start:end]
    print("[{}] process chunk size {}, range [{}:{}]".format(
        job_id, len(prompts), start, end))

    # Generation settings shared by every prompt.
    batch_size = 1
    height = 512
    width = 512
    num_inference_steps = 3
    guidance_scale = 1.155
    denoiser = getattr(pipe, "unet", None)
    if denoiser is None:
        denoiser = pipe.transformer
    num_channels_latents = denoiser.config.in_channels

    image_dir = os.path.join("images", job_id)
    latent_dir = os.path.join("latents", job_id)
    os.makedirs(os.path.join(save_dir, image_dir), exist_ok=True)
    os.makedirs(os.path.join(save_dir, latent_dir), exist_ok=True)

    with open(os.path.join(save_dir, f"meta_{job_id}.json"), "a") as meta:
        for i, prompt in enumerate(prompts):
            prompt = prompt.strip()
            image_path = os.path.join(image_dir, f"{i}.jpg")
            latent_path = os.path.join(latent_dir, f"{i}.pt")

            latents = pipe.prepare_latents(
                batch_size,
                num_channels_latents,
                height,
                width,
                dtype=torch.float16,
                device=pipe.device,
                generator=torch.Generator().manual_seed(i),
            )
            prompt_kwargs = _prompt_kwargs(pipe, prompt)

            try:
                image = pipe(
                    **prompt_kwargs,
                    latents=latents,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale).images[0]

                # NOTE(review): the saved latents are the 512x512 text-to-image
                # latents, while the saved image is the 1024x1024 img2img output;
                # confirm this pairing is intended, and that the img2img pipeline
                # in use accepts `latents` (a TypeError here is logged below).
                final_image = sd_img2img(
                    **prompt_kwargs,
                    latents=latents,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    image=image.resize((1024, 1024)),
                ).images[0]

                final_image.save(os.path.join(save_dir, image_path))
                torch.save(latents, os.path.join(save_dir, latent_path))
            except Exception as err:
                # Best effort: skip this prompt, but report why instead of
                # silently hiding the failure.
                print(f"[{job_id}] skipping prompt {i}: {err!r}")
            else:
                meta.write(json.dumps(
                    {"image_path": image_path, "latent_path": latent_path, "prompt": prompt, "seed": i}) + "\n")
                meta.flush()
                print("saved:", image_path)

def parse_args(argv=None):
    """Parse the command line of the regression-data builder.

    Args:
        argv: argument list to parse; ``None`` (the default) parses the
            process command line, exactly as before.

    Returns:
        argparse.Namespace with ``gpus``, ``workers``, ``size``,
        ``caption_path``, ``model_id`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", default=None, type=str,
                        help="Comma separated list of GPU indices and/or inclusive "
                             "ranges such as '0,2-3' (default: all GPUs)")
    parser.add_argument("--workers", default=1, type=int,
                        help="Number of workers spawned per GPU (default 1)")
    parser.add_argument("--size", default=None, type=int,
                        help="If set, randomly sample this many prompts (with "
                             "replacement) before splitting them across workers")
    parser.add_argument(
        "--caption_path", default="diffusion_db_prompts.txt", type=str,
        help="Text file with one prompt per line")
    parser.add_argument(
        "--model_id", default="runwayml/stable-diffusion-v1-5", type=str,
        help="Model id or path passed to AutoPipelineForText2Image.from_pretrained")
    parser.add_argument(
        "--save_dir", default="data/diffusion_db_runwayml_stable-diffusion-v1-5", type=str,
        help="Output directory for images/, latents/ and meta_<job>.json files")
    return parser.parse_args(argv)

def _parse_gpu_spec(spec, num_gpus):
    """Return the sorted, de-duplicated GPU indices selected by ``--gpus``.

    ``None`` selects ``range(num_gpus)``; otherwise *spec* is a comma separated
    list of indices and inclusive ``lo-hi`` ranges, e.g. ``"0,2-3"``.

    Raises:
        ValueError: on a non-integer entry or a range whose end precedes its start.
    """
    if spec is None:
        return list(range(num_gpus))
    gpus = set()
    for part in spec.split(","):
        if "-" in part:
            lo, hi = (int(x) for x in part.split("-"))
            if hi < lo:
                raise ValueError(f"Invalid GPU range {part!r}: end is before start")
            gpus.update(range(lo, hi + 1))
        else:
            gpus.add(int(part))
    return sorted(gpus)

def _remove_job_outputs(save_dir, job_id):
    """Delete the images, latents and metadata file a job has written so far."""
    for subdir in ("images", "latents"):
        shutil.rmtree(os.path.join(save_dir, subdir, job_id), ignore_errors=True)
    with contextlib.suppress(FileNotFoundError):
        os.remove(os.path.join(save_dir, f"meta_{job_id}.json"))

def main():
    """Launch one ``run`` worker per (GPU, worker slot) and supervise them.

    Workers that exit with code 0 are marked finished. Workers that die with
    any other exit code have their partial outputs deleted and are respawned
    with their own original arguments. Returns once every worker has
    finished; on Ctrl-C all workers are terminated and joined.
    """
    mp.set_start_method("spawn")
    args = parse_args()

    visible_gpus = _parse_gpu_spec(args.gpus, torch.cuda.device_count())
    if not visible_gpus:
        raise SystemExit("No GPUs selected or available.")
    print("Using GPUs: ", visible_gpus)
    if visible_gpus != list(range(len(visible_gpus))):
        # run() derives its prompt slice from the physical device index, so
        # only a selection of the form 0..k-1 covers every prompt exactly once.
        print("Warning: --gpus is not 0..N-1; not every prompt slice will be processed.")

    num_workers = args.workers
    # job_id -> (process, device_id, kwargs). Each job keeps its own kwargs so a
    # crashed job is respawned with ITS arguments, not the last job's.
    jobs = {}
    for device_id in visible_gpus:
        for i in range(num_workers):
            job_id = f"GPU{device_id:02d}-{i}"
            print(f"[{job_id}] Launching worker-process...")
            kwargs = dict(
                device_id=device_id,
                job_id=job_id,
                worker_id=i,
                n_gpu=len(visible_gpus),
                n_worker=num_workers,
                caption_path=args.caption_path,
                model_id=args.model_id,
                save_dir=args.save_dir,
                size=args.size,
            )
            p = mp.Process(target=run, kwargs=kwargs)
            jobs[job_id] = (p, device_id, kwargs)
            p.start()

    try:
        while jobs:
            time.sleep(1)
            # Iterate over a snapshot because finished jobs are removed below.
            for job_id, (job, device_id, kwargs) in list(jobs.items()):
                if job.is_alive():
                    continue
                if job.exitcode == 0:
                    # Finished normally: keep its outputs and stop watching it.
                    print(f"[{job_id}] Worker finished.")
                    del jobs[job_id]
                    continue
                print(f"[{job_id}] Worker died (exit code {job.exitcode}), cleaning up...")
                _remove_job_outputs(args.save_dir, job_id)
                print(f"[{job_id}] respawning...")
                p = mp.Process(target=run, kwargs=kwargs)
                jobs[job_id] = (p, device_id, kwargs)
                p.start()
        print("All workers finished.")

    except KeyboardInterrupt:
        print("Caught KeyboardInterrupt, terminating workers")
        for job, _, _ in jobs.values():
            job.terminate()
        for job_id, (job, _, _) in jobs.items():
            print(f"[{job_id}] waiting for exit...")
            job.join()
        print("done.")

if __name__ == "__main__":
    # Entry-point guard. It is required here because main() selects the
    # "spawn" start method: every worker process re-imports this module, and
    # the guard keeps those imports from calling main() again.
    main()
Superviro commented 2 months ago

@Meshwa428 Thanks! I'm using GPUs from my school lab.

Meshwa428 commented 2 months ago

Damn, your school must be rich 🤑. Thanks for the info, though.