Thank you for the nice issue! @williamberman or @sayakpaul, can you have a look here?

It might be because of the dataloader. I would suggest swapping out `datasets` for a library that supports faster and better image decoding (`webdataset`, for example). Here's an example: https://github.com/huggingface/diffusers/blob/controlnet_webdatasets/examples/controlnet/train_controlnet_webdatasets.py.
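As a rough illustration, a `webdataset` pipeline for image/caption pairs could look something like the sketch below (the shard pattern and the `"jpg"`/`"txt"` keys are assumptions about how the data is packed, not part of the linked script):

```python
import webdataset as wds

# Minimal sketch: stream image/caption pairs straight from tar shards instead of
# decoding through `datasets`. Shard names and sample keys are hypothetical.
dataset = (
    wds.WebDataset("data/shard-{000000..000099}.tar")
    .shuffle(1000)            # sample-level shuffle buffer
    .decode("pil")            # decode image bytes to PIL on the fly
    .to_tuple("jpg", "txt")   # yield (image, caption) pairs
    .batched(4)               # batch inside the pipeline
)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=4)
```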
Some additional questions: the communication protocol also matters a bit. `nccl`, in my experience, does a good job in most cases. In your case, which communication backend is being used?

Hi @sayakpaul,
> datasets vs webdataset:

> The communication protocol also matters a bit. nccl, in my experience, does a good job in most cases. In your case, which communication backend is being used?

This is the output of `accelerator.state` after the initialization:
```
Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 2
Process index: 0
Local process index: 0
Device: cuda:0
Mixed precision type: bf16
```
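(For reference, this state dump can be produced with a minimal snippet like the one below; the `mixed_precision` argument is only an assumption matching the output above.)

```python
from accelerate import Accelerator

# Printing the state shows the distributed backend (e.g. nccl), the number of
# processes, the device, and the mixed-precision setting.
accelerator = Accelerator(mixed_precision="bf16")
print(accelerator.state)
```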
> What happens with FP16? Are results affected in any way if you use `xformers`, for example?

> How to check for it?

It should be written in the specifications of the multi-GPU machine you're using. If possible, maybe ask the sys admin.
I did some batch-loading tests to check the dataloading time for both cases. These are the stats, and they look similar. Maybe it's not a dataloader issue.

| GPUs | precision | batch_per_gpu | effective_batch | avg. batch load time |
|---|---|---|---|---|
| 1 | fp16 | 1 | 1 | ~16 ms |
| 1 | bf16 | 1 | 1 | ~16 ms |
| 2 | fp16 | 1 | 2 | ~18-19 ms |
| 2 | bf16 | 1 | 2 | ~18-19 ms |
> What happens with FP16? Are results affected in any way if you use `xformers`, for example?

I am getting similar it/s and total runtime for fp16 as for bf16. The total runtime for 2 GPUs is still higher than the 1-GPU runtime. Further, the runtime remains the same with `xformers`.
I welcome you to refer to https://github.com/huggingface/diffusers/blob/main/examples/research_projects/controlnet/train_controlnet_webdataset.py to see how we leverage `webdataset` to squeeze out the best performance.

Also cc @patil-suraj in case he has something more to suggest.
Ccing @muellerzr (for `accelerate`-honed distributed training) just for visibility.
Sure, I will try `webdataset` next. Thanks :)
there's no way the dataloader should be responsible, in my opinion, for a 100% slowdown in performance.

you might want to try the `pin_memory()` feature so that the latents from the VAE end up pinned in system memory, which can substantially improve GPU transfer times.
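a minimal sketch of pinned-memory transfer in plain PyTorch (illustrative only, not the trainer's actual code; the shape is made up):

```python
import torch

# stand-in for a cached VAE latent that currently lives on the CPU
latent = torch.randn(4, 4, 64, 64)

# page-lock the host memory so the host-to-device copy can be asynchronous
latent = latent.pin_memory()

if torch.cuda.is_available():
    # non_blocking only has an effect when the source tensor is pinned
    latent = latent.to("cuda", non_blocking=True)
```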
kernel 5.4 might be at fault here. have you any opportunity to try 5.15 or newer?
after testing this a bunch this weekend, i have concluded that the data loading really does need to be done in parallel with multiple GPUs. i had 5x A100-80G which were taking 226 seconds per iteration (bs=15 * grad_steps=4) when a single GPU ran at 26 seconds per iteration.

i added a bunch of logging and determined that just loading the huge images from disk was enough to hurt it, but if your dataset is web-based, the problem is 5-fold.

using concurrent.futures to fetch things in parallel also helps, if doing it manually or with a custom sampler. after all of the optimisations i've made to my trainer, including a metadata tree of image sizes and crop coordinates, i no longer open the original png during training at all, merely the VAE latent cache object.

that means i'm now at 26 seconds per iteration on the same config, for one GPU or eight.
@bghira any chance you have a code sample? Would love to learn and see, nice work :)
@muellerzr it's like a squid's tendril, extending through my trainer's source code. i'm sure there are more Pythonic ways, and things that ML researchers love to code, that would make this feel "icky".

for instance, i have a static class called StateTracker that has some convenience methods for listing filesystem contents. this is because our training data is stored in an S3 bucket, and listing the contents across many processes painfully extends the time to resumption.

the StateTracker's filesystem listing mechanism ends up storing the result on disk, so that the child processes can pick it up. i do this using:
```python
def split_buckets_between_processes(self, gradient_accumulation_steps=1):
    """
    Splits the contents of each bucket in aspect_ratio_bucket_indices between the available processes.
    """
    new_aspect_ratio_bucket_indices = {}
    total_images = sum(
        [len(bucket) for bucket in self.aspect_ratio_bucket_indices.values()]
    )
    logger.debug(f"Count of items before split: {total_images}")

    # Determine the effective batch size for all processes considering gradient accumulation
    num_processes = self.accelerator.num_processes
    effective_batch_size = (
        self.batch_size * num_processes * gradient_accumulation_steps
    )

    for bucket, images in self.aspect_ratio_bucket_indices.items():
        # Trim the list to a length that's divisible by the effective batch size
        num_batches = len(images) // effective_batch_size
        trimmed_images = images[: num_batches * effective_batch_size]

        with self.accelerator.split_between_processes(
            trimmed_images, apply_padding=False
        ) as images_split:
            # Now images_split contains only the part of the images list that this process should handle
            new_aspect_ratio_bucket_indices[bucket] = images_split

    # Replace the original aspect_ratio_bucket_indices with the new one containing only this process's share
    self.aspect_ratio_bucket_indices = new_aspect_ratio_bucket_indices
    logger.debug(
        f"Count of items after split: {sum([len(bucket) for bucket in self.aspect_ratio_bucket_indices.values()])}"
    )
```
to split each aspect bucket by the number of GPUs, batch size, and gradient steps.
for VAE latent caching:
```python
def encode_images(self, images, filepaths, load_from_cache=True):
    """
    Encode a batch of input images. Images must be the same dimension.

    If load_from_cache=True, we read from the VAE cache rather than encode.
    If load_from_cache=True, we will throw an exception if the entry is not found.
    """
    batch_size = len(images)
    if batch_size != len(filepaths):
        raise ValueError("Mismatch between number of images and filepaths.")

    # Generates a hash.
    full_filenames = [
        self.generate_vae_cache_filename(filepath)[0] for filepath in filepaths
    ]

    # Check cache for each image and filter out already cached ones
    uncached_images = []
    uncached_image_indices = [
        i
        for i, filename in enumerate(full_filenames)
        if not self.data_backend.exists(filename)
    ]

    if len(uncached_image_indices) > 0 and load_from_cache:
        # We wanted only uncached images. Something went wrong.
        raise Exception(
            f"Some images were not correctly cached during the VAE Cache operations. Ensure --skip_file_discovery=vae is not set."
        )
    elif not load_from_cache:
        uncached_images = [images[i] for i in uncached_image_indices]

    if load_from_cache:
        # If all images are cached, simply load them
        latents = [self.load_from_cache(filename) for filename in full_filenames]

    # The rest of the function, which then actually encodes the input during pre-processing.
```
For the aspect bucketing manager itself, we scan at startup for any files that exist in the data pool, but do not show up in the metadata document:
```python
def save_image_metadata(self):
    """Save image metadata to a JSON file."""
    self.data_backend.write(self.metadata_file, json.dumps(self.image_metadata))


def scan_for_metadata(self):
    """
    Update the metadata without modifying the bucket indices.
    """
    logger.info(f"Loading metadata from {self.metadata_file}")
    self.load_image_metadata()
    logger.debug(
        f"A subset of the available metadata: {list(self.image_metadata.keys())[:5]}"
    )

    logger.info("Discovering new images for metadata scan...")
    new_files = self._discover_new_files(for_metadata=True)
    if not new_files:
        logger.info("No new files discovered. Exiting.")
        return

    existing_files_set = {
        existing_file for existing_file in self.image_metadata.keys()
    }

    num_cpus = 8  # Using a fixed number for better control and predictability
    files_split = np.array_split(new_files, num_cpus)

    metadata_updates_queue = Queue()
    tqdm_queue = Queue()

    workers = [
        Process(
            target=self._bucket_worker,
            args=(
                tqdm_queue,
                file_shard,
                None,  # Passing None to indicate we don't want to update the buckets
                metadata_updates_queue,
                existing_files_set,
                self.data_backend,
            ),
        )
        for file_shard in files_split
    ]

    for worker in workers:
        worker.start()

    with tqdm(desc="Scanning metadata for images", total=len(new_files)) as pbar:
        while any(worker.is_alive() for worker in workers):
            while not tqdm_queue.empty():
                pbar.update(tqdm_queue.get())

            # Only update the metadata
            while not metadata_updates_queue.empty():
                metadata_update = metadata_updates_queue.get()
                for filepath, meta in metadata_update.items():
                    self.set_metadata_by_filepath(
                        filepath=filepath, metadata=meta, update_json=False
                    )

    for worker in workers:
        worker.join()

    self._save_cache()
    self.save_image_metadata()
    logger.info("Completed metadata update.")
```
the disk I/O happens through a pluggable data backend, each one implementing an abstract class with the same methods so I can use local disk or AWS S3 to do the same operations.
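to illustrate the idea, here is a hypothetical sketch of such a pluggable backend (not the trainer's actual interface; the method names and boto3 usage are assumptions):

```python
from abc import ABC, abstractmethod
import os

import boto3                                  # only needed for the S3 variant
from botocore.exceptions import ClientError


class DataBackend(ABC):
    """Pluggable storage interface: the trainer only ever calls these methods."""

    @abstractmethod
    def read(self, path: str) -> bytes: ...

    @abstractmethod
    def write(self, path: str, data: bytes) -> None: ...

    @abstractmethod
    def exists(self, path: str) -> bool: ...


class LocalDiskBackend(DataBackend):
    def read(self, path):
        with open(path, "rb") as f:
            return f.read()

    def write(self, path, data):
        with open(path, "wb") as f:
            f.write(data)

    def exists(self, path):
        return os.path.exists(path)


class S3Backend(DataBackend):
    def __init__(self, bucket):
        self.bucket = bucket
        self.client = boto3.client("s3")

    def read(self, path):
        return self.client.get_object(Bucket=self.bucket, Key=path)["Body"].read()

    def write(self, path, data):
        self.client.put_object(Bucket=self.bucket, Key=path, Body=data)

    def exists(self, path):
        try:
            self.client.head_object(Bucket=self.bucket, Key=path)
            return True
        except ClientError:
            return False
```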
the sampler is very complicated. it has no optimisations relevant to GPU training time, other than ensuring correctness. the important thing is that it guarantees we receive a full batch of samples with identical dimensions. currently, just one bucket is used with a set pixel area (e.g. 1 megapixel), but future work intends to extend that to support training on multiple pixel-area buckets in a single training session, so that one can use 0.25, 0.5, and 1 megapixel images randomly.
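a stripped-down sketch of that bucketed batching idea (hypothetical, and much simpler than the real sampler):

```python
import random

from torch.utils.data import Sampler


class AspectBucketSampler(Sampler):
    """Yield whole batches at a time, each drawn from a single aspect bucket,
    so every sample in a batch has identical latent dimensions."""

    def __init__(self, bucket_indices, batch_size):
        # bucket_indices: {aspect_ratio: [sample dicts, each with "image_path"]}
        self.bucket_indices = bucket_indices
        self.batch_size = batch_size

    def __iter__(self):
        batches = []
        for samples in self.bucket_indices.values():
            samples = list(samples)
            random.shuffle(samples)
            # drop the ragged tail so every batch is full and uniform
            for i in range(0, len(samples) - self.batch_size + 1, self.batch_size):
                batches.append(tuple(samples[i : i + self.batch_size]))
        random.shuffle(batches)  # mix buckets between batches
        yield from batches

    def __len__(self):
        return sum(len(v) // self.batch_size for v in self.bucket_indices.values())
```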
the dataset class iterator:
```python
def __getitem__(self, image_tuple):
    output_data = []
    for sample in image_tuple:
        image_path = sample["image_path"]
        logger.debug(f"Running __getitem__ for {image_path} inside Dataloader.")

        image_metadata = self.bucket_manager.get_metadata_by_filepath(image_path)
        image_metadata["image_path"] = image_path
        if (
            image_metadata["original_size"] is None
            or image_metadata["target_size"] is None
        ):
            raise Exception(
                f"Metadata was unavailable for image: {image_path}. Ensure --skip_file_discovery=metadata is not set."
                f" Metadata: {self.bucket_manager.get_metadata_by_filepath(image_path)}"
            )

        if self.print_names:
            logger.info(f"Dataset is now using image: {image_path}")

        # Use the magic prompt handler to retrieve the captions.
        image_metadata["instance_prompt_text"] = PromptHandler.magic_prompt(
            data_backend=self.data_backend,
            image_path=image_path,
            caption_strategy=self.caption_strategy,
            use_captions=self.use_captions,
            prepend_instance_prompt=self.prepend_instance_prompt,
        )

        output_data.append(image_metadata)

    return output_data
```
at collate time, the method looks like:
```python
def collate_fn(batch):
    examples = batch[0]
    batch_luminance = [example["luminance"] for example in examples]
    # average it
    batch_luminance = sum(batch_luminance) / len(batch_luminance)

    filepaths = extract_filepaths(examples)
    latent_batch = compute_latents(filepaths)
    check_latent_shapes(latent_batch, filepaths)

    # Extract the captions from the examples.
    captions = [example["instance_prompt_text"] for example in examples]
    prompt_embeds_all, add_text_embeds_all = compute_prompt_embeddings(captions)

    batch_time_ids = gather_conditional_size_features(
        examples, latent_batch, StateTracker.get_weight_dtype()
    )

    return {
        "latent_batch": latent_batch,
        "prompt_embeds": prompt_embeds_all,
        "add_text_embeds": add_text_embeds_all,
        "batch_time_ids": batch_time_ids,
        "batch_luminance": batch_luminance,
    }
```
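for context, one hypothetical way these pieces could be wired into a torch DataLoader, assuming a sampler like the sketch above that yields a whole batch (a tuple of sample dicts) per draw (this is not necessarily the trainer's exact setup):

```python
from torch.utils.data import DataLoader

# the sampler yields one full batch per draw, which is why __getitem__ above
# receives a tuple and collate_fn unwraps batch[0].
sampler = AspectBucketSampler(bucket_indices, batch_size=4)
train_dataloader = DataLoader(
    dataset,
    batch_size=1,          # the sampler already emits full batches
    sampler=sampler,
    collate_fn=collate_fn,
)
```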
the key components of the speed-up here are the precomputed VAE latents that are retrieved from disk. even better would be to implement pre-fetch with a background worker thread that ensures a number of batches remain accessible ahead of time, at all times. but the complexity of that is a turn-off, even if it would speed up I/O further.
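a rough sketch of what such a background pre-fetcher could look like (hypothetical, not implemented in the trainer):

```python
import queue
import threading


class BatchPrefetcher:
    """Keep a small queue of ready batches so collation/I/O overlaps with GPU work."""

    def __init__(self, dataloader, depth=2):
        self.queue = queue.Queue(maxsize=depth)
        self.thread = threading.Thread(target=self._fill, args=(dataloader,), daemon=True)
        self.thread.start()

    def _fill(self, dataloader):
        for batch in dataloader:
            self.queue.put(batch)   # blocks once `depth` batches are waiting
        self.queue.put(None)        # sentinel: dataloader exhausted

    def __iter__(self):
        while True:
            batch = self.queue.get()
            if batch is None:
                break
            yield batch
```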
you'd think retrieving the VAE cache entries would be painful, but threading it works like a charm:
```python
import concurrent.futures

import torch


def fetch_latent(fp):
    """Worker method to fetch latent for a single image."""
    debug_log(" -> pull latents from cache")
    latent = StateTracker.get_vaecache().encode_image(None, fp)

    # Move to CPU and pin memory if it's not on the GPU
    debug_log(" -> push latents to GPU via pinned memory")
    latent = latent.to("cpu").pin_memory()
    return latent


def compute_latents(filepaths):
    # Use a thread pool to fetch latents concurrently
    with concurrent.futures.ThreadPoolExecutor() as executor:
        latents = list(executor.map(fetch_latent, filepaths))

    # Validate shapes
    test_shape = latents[0].shape
    for idx, latent in enumerate(latents):
        if latent.shape != test_shape:
            raise ValueError(
                f"File {filepaths[idx]} latent shape mismatch: {latent.shape} != {test_shape}"
            )

    debug_log(" -> stacking latents")
    return torch.stack(latents)
```
no slowdown now from training with the S3 backend, or with 14 GPUs.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I ran into the same problem finetuning the SDXL model.
I met the same problem when finetuning Stable Diffusion.
These scripts are not optimized to give you the best throughput when using multiple GPUs. Just as a caution.
Describe the bug

I am attempting to train the SD1.5 model for text-to-image generation using the train_text_to_image.py script on a custom dataset. Given the large size of the dataset, and having access to a server with 2 A100 GPUs, I initiated multi-GPU training as instructed in the README file.

Unexpectedly, the multi-GPU training runs significantly slower and does not deliver the anticipated speed-up. For instance, I conducted a test where a single-GPU run used `batch_size=1` for `15000` steps, and a two-GPU run used `batch_size=2` for `7500` steps, thus equating the total number of examples seen during training in both scenarios. I anticipated that the two-GPU training would be at least `1.5-2X` faster, yet it ran slower than the single-GPU training. Below is a table summarizing the observed performance:

I am seeking assistance to understand what might be going wrong, and how to rectify this issue to achieve the expected performance improvement with multi-GPU training.
Reproduction

See the attached requirements.txt to see the env packages.

Single-GPU training

Two-GPU training
Training logs

Single-GPU training: `2.1` steps per second (`2.1it/s`)
Two-GPU training: `0.84` steps per second (`1.19s/it`)

Logs

No response
System Info

`diffusers` version: 0.22.0.dev0

Who can help?
@sayakpaul @yiyixuxu @williamberman @patrickvonplaten