And once I upgrade torch to, say, 2.2.0 with cu121, I get the same error as in https://github.com/tianweiy/DMD2/issues/41
I think it requires torch 2.0.1, which unfortunately doesn't seem to have a build with CUDA 12.x.
The deeper reason for the error is probably some quirk in accelerate's FSDP implementation. To solve this, we basically need to use the raw FSDP wrapper (https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel) instead of relying on accelerator.prepare.
A code snippet looks like the following:
```python
import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy


def fsdp_auto_wrap_policy(model, transformer_layer_name):
    import functools

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    def lambda_policy_fn(module):
        # Wrap leaf modules that own a trainable weight (e.g. embeddings, norms).
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_name,),
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy


def fsdp_wrap(module, transformer_layer_cls, sharding_strategy="hybrid", mixed_precision=False):
    if mixed_precision:
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    else:
        mixed_precision_policy = None

    auto_wrap_policy = fsdp_auto_wrap_policy(module, transformer_layer_cls)

    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "grad":
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    elif sharding_strategy == "hybrid":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
        # Might improve inter-node all-reduce according to the pytorch doc:
        # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy
        os.environ["NCCL_CROSS_NIC"] = "1"
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        os.environ["NCCL_CROSS_NIC"] = "1"
    elif sharding_strategy == "no_shard":
        sharding_strategy = ShardingStrategy.NO_SHARD
    else:
        msg = f"Sharding strategy {sharding_strategy} must be one of 'full', 'grad', 'hybrid', 'hybrid_zero2', 'no_shard'"
        raise NotImplementedError(msg)

    module = FSDP(
        module,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
    )
    return module
```
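For reference, a minimal usage sketch (assuming diffusers' BasicTransformerBlock as the transformer layer class to shard on; the model id and settings are just placeholders, not the actual DMD2 training setup):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

# Wrap the SDXL UNet directly with the raw FSDP helper above,
# instead of passing it through accelerator.prepare.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet = fsdp_wrap(
    unet,
    transformer_layer_cls=BasicTransformerBlock,
    sharding_strategy="hybrid",
    mixed_precision=True,
)
```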
The auto_wrap policy needs to be adapted to SDXL (e.g., using size-based wrapping instead of the transformer-based one here); a sketch follows below.
I don't have time to test this at the moment, but it works well in another project's codebase. Once this is fixed, it should support any torch / CUDA version.
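A minimal sketch using PyTorch's built-in size-based policy (the 100M-parameter threshold is just an illustrative guess, not a tuned value):

```python
import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Shard any submodule whose un-wrapped parameter count exceeds ~100M,
# instead of matching specific transformer layer classes.
auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=int(1e8),
)
# Pass this policy to FSDP(..., auto_wrap_policy=auto_wrap_policy) inside fsdp_wrap.
```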
Hi @tianweiy, I removed FSDP and got it to train with bs=2 (with bf16) on a single 80G GPU, thanks.
However, when I tested it on two cards, I can fit at most bs=1 per GPU. Is this expected?
Besides, it would be even better if you could share the training loss, etc.
The task is SDXL 4-step distillation.
> However, when I tested it on two cards, I can fit at most bs=1 per GPU. Is this expected?
DDP and the like add multiple GB of extra overhead, so that is possible. Once FSDP is gone, I think you can enable gradient checkpointing; this might save some memory (you might need to add gradient checkpointing when computing the GAN loss too). Full bf16 might or might not have precision issues. Interested to see how it goes for you.
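For what it's worth, a minimal sketch of turning gradient checkpointing on for a diffusers UNet (the model id is a placeholder, and exactly where it needs to be enabled in DMD2, e.g. also on the fake score UNet used for the GAN loss, is an assumption):

```python
from diffusers import UNet2DConditionModel

# Placeholder: the generator (student) UNet; DMD2's actual attribute names differ.
generator_unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# Gradient checkpointing recomputes activations in the backward pass
# instead of storing them, trading extra compute for lower memory.
generator_unet.enable_gradient_checkpointing()
```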
One more trick you could do is to offload the real UNet to CPU after computing the DMD loss; this will save quite a few GB, I think.
Basically, load it onto the GPU at https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L178 and offload it back to CPU at https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L255.
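A minimal sketch of that load/offload pattern (the function and argument names are placeholders, not the actual sd_guidance.py API):

```python
import torch


def dmd_loss_with_offloaded_teacher(real_unet: torch.nn.Module, compute_dmd_loss, *args, **kwargs):
    """Keep the frozen teacher ("real") UNet on the GPU only while the DMD loss needs it."""
    real_unet.to("cuda")       # load the teacher right before the DMD loss
    loss = compute_dmd_loss(*args, **kwargs)
    real_unet.to("cpu")        # offload immediately after the DMD loss
    torch.cuda.empty_cache()   # let the allocator release the teacher's cached blocks
    return loss
```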
It will be slow, but the whole codebase is not very fast at the moment anyway, lol. A newer PyTorch version seems to be faster.
> Besides, it would be even better if you could share the training loss, etc.
see logs.pdf
By the way, feel free to let me know if you have any questions about how useful a specific loss / hyperparameter / etc. is.
Hi @tianweiy, thanks a lot for your help.
Now I can train DMD on multiple GPUs with bs=2 per card, but the results are not good so far. According to your logs, after 2k iterations we should see some nice generated pictures.
However, the DM loss is not dropping, and the generated image std is not increasing...
Have you tried bf16 training on SDXL before?
I didn't. Actually, could you send me an email? We can probably set up a call to figure out the issues. Thanks.
I created the conda env following the README and got this error when training SDXL.
It seems that the installed torch version is built for cu11, but my GPU has CUDA 12.2 (checked with nvidia-smi).
However, the previous issue https://github.com/tianweiy/DMD2/issues/41 indicates the only working env is the one that follows the README?
Massive thanks for your work!