tianweiy / DMD2


Training a 4-step SDXL model: CUDA error: no kernel image is available for execution on the device #47

Closed: yuanzhi-zhu closed this issue 1 week ago

yuanzhi-zhu commented 1 week ago

I created the conda env following the README and got this error when training SDXL.

It seems that the installed torch build targets CUDA 11 (cu11x), but my GPU setup reports CUDA 12.2 (I checked with nvidia-smi).
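One quick way to narrow this error down is to compare the GPU's compute capability with the architectures the installed wheel was compiled for, since `no kernel image is available` commonly means that no compatible kernel was built for the device. Note that the CUDA version printed by nvidia-smi is the highest version the driver supports, not the version torch was built with. The sketch below is a hypothetical helper (`arch_supported` is not part of DMD2 or PyTorch). It approximates NVIDIA's compatibility rules: SASS (`sm_XY`) runs on the same major version with an equal or higher minor version, and PTX (`compute_XY`) can be JIT-compiled for newer GPUs. It only calls `torch.cuda.get_device_capability()`, `torch.cuda.get_arch_list()` and `torch.version.cuda` when torch and a GPU are present.

```python
def arch_supported(capability, arch_list):
    """Approximate check: does any compiled arch cover this (major, minor) GPU?"""
    major, minor = capability
    for arch in arch_list:
        kind, _, num = arch.partition("_")
        # Keep only the digits, e.g. "sm_86" -> 86; suffixes like "90a" are
        # ignored, which is an approximation.
        digits = "".join(ch for ch in num if ch.isdigit())
        if not digits:
            continue
        t_major, t_minor = divmod(int(digits), 10)
        if kind == "sm" and t_major == major and t_minor <= minor:
            # SASS binaries run on the same major arch with equal/higher minor.
            return True
        if kind == "compute" and (t_major, t_minor) <= (major, minor):
            # PTX can be JIT-compiled by the driver for newer GPUs.
            return True
    return False


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        archs = torch.cuda.get_arch_list()
        print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
        print("GPU compute capability", cap, "compiled archs", archs)
        print("compatible kernels expected:", arch_supported(cap, archs))
```

If this prints `False`, a wheel built for a newer CUDA toolkit (which usually includes more recent architectures) is the first thing to try.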

However, the previous issue https://github.com/tianweiy/DMD2/issues/41 suggests that the only working environment is the one from the README?

Massive thanks for your work

yuanzhi-zhu commented 1 week ago

And once I upgrade torch to, say, 2.2.0 with cu121, I get the same error as in https://github.com/tianweiy/DMD2/issues/41

tianweiy commented 1 week ago

I think it requires torch 2.0.1, which unfortunately doesn't seem to have a build for CUDA 12.x.

The deeper cause of the error is probably some mysterious implementation detail of FSDP in accelerate. To solve this, we basically need to use a raw FSDP wrapper (https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel) instead of relying on accelerator.prepare.

Some code snippets look like the following:

import functools
import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy


def fsdp_auto_wrap_policy(model, transformer_layer_cls):
    # `model` is unused; it is kept so the call in fsdp_wrap stays unchanged.

    def lambda_policy_fn(module):
        # Wrap leaf modules (no children) that own a trainable weight.
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    # Give every instance of the given transformer block class its own FSDP unit.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_cls,),
    )

    # A module is wrapped if either policy matches.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy


def fsdp_wrap(module, transformer_layer_cls, sharding_strategy="hybrid", mixed_precision=False):
    # Optional bf16 mixed precision for parameters, gradient reduction and buffers.
    if mixed_precision:
        mixed_precision_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )
    else:
        mixed_precision_policy = None

    auto_wrap_policy = fsdp_auto_wrap_policy(module, transformer_layer_cls)

    # Map the short strategy name to FSDP's ShardingStrategy enum.
    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "grad":
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
    elif sharding_strategy == "hybrid":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
        # Might improve inter-node all-reduce according to pytorch doc:
        #   https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy
        os.environ["NCCL_CROSS_NIC"] = "1"
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        os.environ["NCCL_CROSS_NIC"] = "1"
    elif sharding_strategy == "no_shard":
        sharding_strategy = ShardingStrategy.NO_SHARD
    else:
        msg = (
            f"Sharding strategy {sharding_strategy} must be one of "
            "'full', 'grad', 'hybrid', 'hybrid_zero2', 'no_shard'"
        )
        raise NotImplementedError(msg)
    module = FSDP(
        module,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,  # broadcast rank 0's weights so all ranks start identical
    )
    return module

The auto-wrap policy needs to be adapted to SDXL (e.g., using size-based wrapping instead of the transformer-based one here).
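To illustrate what size-based wrapping would do, here is a toy, torch-free simulation. It assumes the same bottom-up rule as `torch.distributed.fsdp.wrap.size_based_auto_wrap_policy`: a module gets its own FSDP unit once its not-yet-wrapped parameters reach `min_num_params`, and the root is always wrapped by the outer `FSDP(...)` call. The names `Node` and `plan_size_based_wrapping` are hypothetical, and the tree and its sizes (in millions of parameters) are made up; they are not SDXL's real layout.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    """A module in a toy model tree (hypothetical sizes, millions of params)."""
    name: str
    own_params: int  # parameters held directly by this module, not by children
    children: List["Node"] = field(default_factory=list)


def _wrap_subtree(node: Node, min_num_params: int, wrapped: List[str]) -> int:
    """Wrap children first (bottom-up); return params left unwrapped."""
    remaining = node.own_params
    for child in node.children:
        remaining += _wrap_subtree(child, min_num_params, wrapped)
    # Wrap this module once its not-yet-wrapped params reach the threshold.
    if remaining >= min_num_params:
        wrapped.append(node.name)
        return 0
    return remaining


def plan_size_based_wrapping(root: Node, min_num_params: int) -> Tuple[List[str], int]:
    """Return (wrapped module names in wrap order, params left in the root unit)."""
    wrapped: List[str] = []
    remaining = root.own_params
    for child in root.children:
        remaining += _wrap_subtree(child, min_num_params, wrapped)
    wrapped.append(root.name)  # the outermost FSDP(...) call always wraps the root
    return wrapped, remaining


toy_unet = Node("unet", 10, [
    Node("down", 0, [Node("down.0", 300), Node("down.1", 50)]),
    Node("mid", 200),
    Node("up", 0, [Node("up.0", 400), Node("up.1", 80)]),
])

if __name__ == "__main__":
    print(plan_size_based_wrapping(toy_unet, min_num_params=150))
```

With torch itself, the corresponding policy would be something like `functools.partial(size_based_auto_wrap_policy, min_num_params=...)` passed as `auto_wrap_policy`. The threshold is a tuning choice: roughly, a smaller value gives more, smaller FSDP units (lower peak memory per all-gather, more communication calls).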

I don't have time to test this right now, but it works well in another project's codebase. Once this is fixed, it should support any torch / CUDA version.

yuanzhi-zhu commented 1 week ago

Hi @tianweiy, I removed FSDP and got it to train with bs=2 (with bf16) on a single 80 GB GPU, thanks.

However, when I tested it on two cards, I could fit at most bs=1 per GPU. Is this expected?

Besides, it would be even better if you could share the training loss, etc.

The task is SDXL 4-step distillation.

tianweiy commented 1 week ago

> However, when I tested it on two cards, I could fit at most bs=1 per GPU. Is this expected?

DDP and similar wrappers add several GB of extra overhead, so it is possible. Without FSDP, I think you can enable gradient checkpointing, which might save some memory (you might need to add gradient checkpointing when computing the GAN loss too). Full bf16 may or may not have precision issues. Interested to see how it goes for you.
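A back-of-the-envelope estimate of that DDP overhead, assuming it is dominated by DDP's gradient buckets: the PyTorch DDP docs note that `gradient_as_bucket_view=True` saves memory equal to the total gradient size, which suggests the buckets otherwise hold roughly one extra copy of all gradients. The helper name and the ~2.6e9 parameter count for an SDXL-sized UNet are assumptions for illustration.

```python
def ddp_bucket_overhead_gb(num_trainable_params: float, bytes_per_grad: int = 4) -> float:
    """Rough size (GB) of DDP's gradient buckets, assumed ~one extra copy of all grads."""
    return num_trainable_params * bytes_per_grad / 1e9


if __name__ == "__main__":
    # Hypothetical: one SDXL-sized UNet with ~2.6e9 trainable parameters.
    print(ddp_bucket_overhead_gb(2.6e9))                    # fp32 grads, about 10.4 GB
    print(ddp_bucket_overhead_gb(2.6e9, bytes_per_grad=2))  # bf16 grads, about 5.2 GB
```

If several networks are trainable under DDP (e.g. the generator and the fake score network), their counts add up. Passing `gradient_as_bucket_view=True` to `DistributedDataParallel` may recover part of this memory.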

One more trick is to offload the real UNet to the CPU after computing the DMD loss. I think this will save quite a few GB.

Basically, load it to the GPU at https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L178

and offload it to the CPU at https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L255
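The load/offload pattern can be wrapped in a small context manager. This is a hypothetical sketch (`temporarily_on` and `FakeModule` are not in the DMD2 codebase). It only assumes the module has a `.to(device)` method like `torch.nn.Module`, so the demo uses a stand-in object that records where it was moved.

```python
from contextlib import contextmanager


@contextmanager
def temporarily_on(module, device, offload_device="cpu", empty_cache=None):
    """Move `module` to `device` for the duration of the block, then offload it."""
    module.to(device)
    try:
        yield module
    finally:
        module.to(offload_device)
        if empty_cache is not None:
            # e.g. torch.cuda.empty_cache, which releases unused cached blocks
            # held by PyTorch's allocator back to the driver.
            empty_cache()


class FakeModule:
    """Stand-in for an nn.Module that records every .to() call."""

    def __init__(self):
        self.calls = []

    def to(self, device):
        self.calls.append(device)
        return self


fake = FakeModule()
cleared = []
with temporarily_on(fake, "cuda", empty_cache=lambda: cleared.append(True)):
    fake.calls.append("compute")  # stand-in for computing the DMD loss

# With torch (not run here; real_unet is a placeholder name):
#   with temporarily_on(real_unet, "cuda", empty_cache=torch.cuda.empty_cache):
#       ...compute the DMD loss...
```

Using `try`/`finally` keeps the offload even if the loss computation raises, so a failed step does not leave the real UNet occupying GPU memory.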

It will be slow, but the whole codebase is not very fast at the moment anyway, lol. I think newer PyTorch versions seem to be faster.

> Besides, it would be even better if you could share the training loss, etc.

See logs.pdf.

By the way, feel free to let me know if you have any questions about how useful a specific loss / hyperparameter / etc. is.

yuanzhi-zhu commented 1 week ago

Hi @tianweiy, thanks a lot for your help.

Now I can train DMD on multiple GPUs with bs=2 per card, but the results are not good so far. According to your logs, we should see some nice generated images after 2k iterations.

However, the DM loss is not dropping, and the std of the generated images is not increasing...

Have you tried bf16 training on SDXL before?
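For context on the bf16 question, here is a torch-free sketch of one known bf16 pitfall. bf16 keeps only 8 significant bits, so near 1.0 neighboring values are 2^-7 = 0.0078125 apart, and a small update added to a bf16 weight can be rounded away entirely. Whether this is what stalls the DM loss here is not established in the thread. The weight and update values are hypothetical, and `to_bf16` / `bf16_add` are illustrative helpers that simulate bf16 rounding (via float32, round to nearest even, NaN not handled) rather than calling any library.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (via float32, ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # bfloat16 keeps the top 16 bits of a float32; add a rounding bias first.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_add(a: float, b: float) -> float:
    """Simulate adding two bf16 values and storing the result in bf16."""
    return to_bf16(to_bf16(a) + to_bf16(b))


if __name__ == "__main__":
    weight, update = 1.0, 1e-3  # hypothetical weight and lr * grad step
    print(weight + update)            # higher precision keeps the update
    print(bf16_add(weight, update))   # bf16 storage rounds it back to 1.0
    print(to_bf16(1.0 + 1e-2))        # the next bf16 value above 1.0
```

This is the usual reason mixed-precision setups keep fp32 master weights and only run the forward/backward math in bf16.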

tianweiy commented 1 week ago

I didn't. Actually, could you send me an email? We can probably set up a call to figure out the issues. Thanks.

(Sent by email in reply to Yuanzhi Zhu's comment of Sep 6, 2024, 3:24 AM.)