ToTheBeginning / PuLID

[NeurIPS 2024] Official code for PuLID: Pure and Lightning ID Customization via Contrastive Alignment

Flux model is being fully created before bfloat16 cast #67

Open PRPA1984 opened 1 month ago

PRPA1984 commented 1 month ago

When loading the Flux model, the entire model is created before the cast:

`model = Flux(configs[name].params).to(torch.bfloat16)`

The problem is that model creation consumes a large amount of RAM: every submodule is first initialized with random float32 parameters, and only afterwards is the whole model cast to bfloat16. In the meantime, I fixed this by casting each submodule to bfloat16 as it is created.
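
A minimal sketch of an alternative that avoids touching every submodule: set the default dtype around construction, so floating-point parameters are allocated in bfloat16 from the start and the float32 copy (and its RAM spike) never exists. `TinyBlock` below is a stand-in for the real Flux modules, not code from this repo:

```python
import torch
import torch.nn as nn

# Stand-in module: the real fix would wrap the Flux(...) constructor
# the same way.
class TinyBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

prev = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
try:
    model = TinyBlock(4096)  # parameters are created directly in bfloat16
finally:
    torch.set_default_dtype(prev)  # restore so other code is unaffected

assert model.proj.weight.dtype == torch.bfloat16
```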

denred0 commented 1 month ago

> When loading the Flux model, the entire model is created before the cast:
>
> `model = Flux(configs[name].params).to(torch.bfloat16)`
>
> The problem is that model creation consumes a large amount of RAM: every submodule is first initialized with random float32 parameters, and only afterwards is the whole model cast to bfloat16. In the meantime, I fixed this by casting each submodule to bfloat16 as it is created.

For everyone with the same problem: you need to add `.to(torch.bfloat16)` in `flux/model.py`, here:

```python
self.double_blocks = nn.ModuleList(
    [
        DoubleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
            qkv_bias=params.qkv_bias,
        ).to(torch.bfloat16)  # cast each block as it is created
        for _ in range(params.depth)
    ]
)
```

and here:

```python
self.single_blocks = nn.ModuleList(
    [
        SingleStreamBlock(
            self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
        ).to(torch.bfloat16)  # cast each block as it is created
        for _ in range(params.depth_single_blocks)
    ]
)
```
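
Another option, since the randomly initialized weights are immediately overwritten by the checkpoint anyway, is to build the model on the meta device so the initial parameters take no memory at all. A minimal sketch, assuming PyTorch >= 2.1 for `load_state_dict(..., assign=True)`; `TinyBlock` and the hand-built state dict stand in for the real Flux model and its safetensors checkpoint:

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

# On the meta device, parameters carry only shape/dtype metadata,
# so construction allocates no weight memory at all.
with torch.device("meta"):
    model = TinyBlock(4096)

# Stand-in for the real checkpoint load.
sd = {
    "proj.weight": torch.zeros(4096, 4096, dtype=torch.bfloat16),
    "proj.bias": torch.zeros(4096, dtype=torch.bfloat16),
}

# assign=True swaps the meta parameters for the loaded bf16 tensors
# instead of copying into storage that was never allocated.
model.load_state_dict(sd, assign=True)
assert model.proj.weight.dtype == torch.bfloat16
```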