PRPA1984 opened 1 month ago
When loading the Flux model, the entire model is created before the cast:

```python
model = Flux(configs[name].params).to(torch.bfloat16)
```

The issue is that a lot of RAM is consumed during model creation, because every submodule is first initialized with random float32 parameters. I worked around this in the meantime by casting each submodule as soon as it is created.
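A rough sketch of the difference, using a toy module rather than the actual Flux blocks (sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Toy stand-in for one transformer block (hypothetical sizes).
def make_block() -> nn.Module:
    return nn.Linear(1024, 1024)

# Problematic pattern: every block exists in float32 before the cast,
# so peak RAM is roughly the full float32 model, i.e. ~2x the final
# bfloat16 footprint.
blocks = nn.ModuleList([make_block() for _ in range(8)]).to(torch.bfloat16)

# Workaround: cast each block right after construction, so at most one
# float32 block is alive at a time on top of the already-cast ones.
blocks = nn.ModuleList([make_block().to(torch.bfloat16) for _ in range(8)])
```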
For anyone with the same problem: you need to add `.to(torch.bfloat16)` in `flux/model.py` here:
```python
self.double_blocks = nn.ModuleList(
    [
        DoubleStreamBlock(
            self.hidden_size,
            self.num_heads,
            mlp_ratio=params.mlp_ratio,
            qkv_bias=params.qkv_bias,
        ).to(torch.bfloat16)
        for _ in range(params.depth)
    ]
)
```
and here:
```python
self.single_blocks = nn.ModuleList(
    [
        SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio).to(torch.bfloat16)
        for _ in range(params.depth_single_blocks)
    ]
)
```
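If the random init is immediately overwritten by a checkpoint load (which I believe is the case in the Flux loading path, but haven't verified), an alternative is to skip the float32 allocation entirely: build the model on the meta device, then materialize the weights straight from the checkpoint. In this sketch, `state_dict` stands in for whatever the loading code already has, the import path is my assumption, and `assign=True` requires torch >= 2.1:

```python
import torch
from flux.model import Flux  # assumed import path for the Flux class

# Build the module tree on the meta device: no parameter memory is
# allocated and no random init runs.
with torch.device("meta"):
    model = Flux(configs[name].params)

# Materialize parameters directly from the checkpoint tensors.
# With assign=True the meta tensors are replaced by the checkpoint
# tensors, so the model ends up with the checkpoint's dtype (e.g.
# bfloat16) without any intermediate float32 copy.
model.load_state_dict(state_dict, assign=True)
```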