Open · ikarsokolov opened this issue 2 months ago
The correct way to handle this is a fairly subtle problem and might depend on the state of the Stable Diffusion ecosystem at large. It's not necessarily as simple as upcasting the float8 weights into the LoRA dtype, or vice versa. Sorry there's no quick fix here.
Autocast doesn't seem to work in my experiments, sadly (I wrapped it in an autocast context for the train_dtype, which was float16). We'd need to explicitly upcast the fp8 tensors. I'm not sure how to guarantee that's a safe operation that reflects unit scaling. Does Torch even have a standard unit scaling API? Is this implemented in an ad-hoc way by Nvidia Transformer Engine and others?
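Something like this is what I mean by explicit upcasting (just a rough sketch, the helper name and dtype check are made up, and it assumes the fp8 values carry no separate scale factor):

```python
import torch

def upcast_fp8_weight(weight: torch.Tensor, train_dtype: torch.dtype) -> torch.Tensor:
    # Sketch only: cast fp8 base weights up to the train dtype before they are
    # combined with the LoRA/DoRA tensors. Assumes the fp8 values are stored
    # unscaled, i.e. there is no per-tensor scale factor to apply.
    if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return weight.to(train_dtype)
    return weight
```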
Looks like there's no standard unit scaling. It would be nice if bitsandbytes supported FP8, then everyone would just use it as the standard! To fix this bug, we're going to have to make the arbitrary assumption that FP8 weights are unscaled. As far as I know, anything present in the community that is actually FP8 is unscaled, so this shouldn't be a big restriction. If we get future checkpoints in FP8 that were trained with Transformer Engine or something else, we'll need to revisit that. I'm hopeful that everyone just uses bitsandbytes for everything.
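To make concrete what that assumption buys us (sketch only, the scale argument is hypothetical): dequantizing an unscaled fp8 weight is just a dtype cast, whereas a scaled checkpoint in the style of Transformer Engine would also carry a scale factor that has to be multiplied back in.

```python
import torch
from typing import Optional

def dequantize_fp8(weight: torch.Tensor, train_dtype: torch.dtype,
                   scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Unscaled fp8 (what we assume for everything in the community today):
    # the stored values are the real weights, so dequantization is a plain cast.
    upcast = weight.to(train_dtype)
    # A scaled format would additionally carry a per-tensor scale to undo.
    return upcast if scale is None else upcast * scale.to(train_dtype)
```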
@ikarsokolov There is now a beta branch with proper support for fp8 LoRA training; the branch name is fp8.
Please give it a try if you know how to use branches.
What happened?
When I activate Decomposed Weights (DoRA) training in the "Lora" tab and have the base SDXL model loaded as float8 in the "model" tab, the training process fails to start with:
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half.
If the DoRA toggle is deactivated, training starts as usual.
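For reference, what looks like the same promotion error can be triggered outside OneTrainer with a tiny standalone snippet (just an illustration, not the actual training code):

```python
import torch

w = torch.randn(4, 4).to(torch.float8_e4m3fn)  # base weight stored as float8
d = torch.randn(4, 4, dtype=torch.float16)     # half-precision DoRA/LoRA delta

# Mixing the two dtypes triggers type promotion, which float8 does not support:
# RuntimeError: Promotion for Float8 Types is not supported, attempted to
# promote Float8_e4m3fn and Half
w + d
```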
What did you expect would happen?
DoRA training working.
Relevant log output
Output of pip freeze
No response