huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Can't prepare model after Quanto is applied on DistributedDataParallel #3040

Closed: bghira closed this issue 1 month ago

bghira commented 2 months ago

System Info

Information

Tasks

Reproduction

import torch, accelerate
from diffusers import FluxTransformer2DModel
from optimum.quanto import quantize, qint8, freeze
weight_dtype = torch.bfloat16

accelerator = accelerate.Accelerator()

bfl_model = 'black-forest-labs/FLUX.1-dev'
transformer = FluxTransformer2DModel.from_pretrained(bfl_model, torch_dtype=torch.bfloat16, subfolder="transformer")

# you might need 'with accelerator.main_process_first()' if your server lacks system mem
print('quantizing')
quantize(transformer, qint8)
print('freezing')
freeze(transformer)

tpacked_noisy_latents = torch.randn(1, 1024, 64, dtype=weight_dtype, device=accelerator.device)
tpooled_projections = torch.randn(1, 768, dtype=weight_dtype, device=accelerator.device)
ttimesteps = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tguidance = torch.randn(1, dtype=weight_dtype, device=accelerator.device)
tencoder_hidden_states = torch.randn(1, 512, 4096, dtype=weight_dtype, device=accelerator.device)
ttxt_ids = torch.randn(1, 512, 3, dtype=weight_dtype, device=accelerator.device)
timg_ids = torch.randn(1, 4320, 3, dtype=weight_dtype, device=accelerator.device)

#with torch.no_grad():
#    model_pred = transformer(
#        hidden_states=tpacked_noisy_latents,
#        timestep=ttimesteps,
#        guidance=tguidance,
#        pooled_projections=tpooled_projections,
#        encoder_hidden_states=tencoder_hidden_states,
#        txt_ids=ttxt_ids,
#        img_ids=timg_ids,
#        joint_attention_kwargs=None,
#        return_dict=False,
#    )
transformer = accelerator.prepare(transformer)

Running this script shows that there are uninitialised, empty parameters after the accelerator.prepare() call.
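
For context, a minimal check along these lines (not part of the original script, names are illustrative) can list which parameters come back empty or still on the meta device after prepare():

# hypothetical diagnostic: flag parameters that look uninitialised after prepare()
for name, param in transformer.named_parameters():
    if param.is_meta or param.numel() == 0:
        print(f"uninitialised parameter: {name}, shape={tuple(param.shape)}, device={param.device}")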

Expected behavior

accelerator.prepare() should succeed and return the quantized model with all parameters materialised.
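
As a sketch of the expected outcome (assuming a multi-process run started with accelerate launch), prepare() would hand back the transformer wrapped for DistributedDataParallel with its quantized weights still materialised on the local device:

# sketch only: what a successful prepare() would look like under DDP
prepared = accelerator.prepare(transformer)
print(type(prepared))                      # torch.nn.parallel.DistributedDataParallel in multi-GPU runs
print(next(prepared.parameters()).device)  # the local accelerator device, not meta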

bghira commented 2 months ago

cc @sayakpaul and @muellerzr

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

bghira commented 1 month ago

This now seems to work on accelerate v0.34.2 and diffusers v0.30.3.
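
A quick way to confirm an environment matches the combination reported as working (a minimal sketch, not from the original comment):

# sketch: verify installed versions against the reported working combination
import accelerate
import diffusers
print(accelerate.__version__)  # reported working: 0.34.2
print(diffusers.__version__)   # reported working: 0.30.3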