NVIDIA / Stable-Diffusion-WebUI-TensorRT

TensorRT Extension for Stable Diffusion Web UI
MIT License

Very temporary fix to use LoRA with fp8 weight enabled #209

Open wangziyao318 opened 11 months ago

wangziyao318 commented 11 months ago

This is only for those who want to lower VRAM usage a bit (to be able to use a larger batch size) when using TensorRT with sd webui, at the cost of some accuracy.

As far as I can tell, the fp8 optimization (currently available in the sd webui dev branch, under Settings/Optimizations) only slightly reduces VRAM usage when used with TensorRT (from 10.9 GB to 9.6 GB for a certain SDXL training run, compared with a drop from 9.7 GB to 6.8 GB without TensorRT), because the TensorRT side still stores data in fp16. VRAM usage would decrease further if TensorRT also had an option to store data in fp8.

LoRA can't be converted to TensorRT with fp8 weight enabled because of a dtype cast issue. Here's a very temporary and dirty fix to get it working (on the dev branch).

In model_helper.py, line 178:

wt = wt.cpu().detach().half().numpy().astype(np.float16)  # added .half(): cast fp8 -> fp16 before converting to numpy

In exporter.py, lines 80 and 82:

wt_hash = hash(wt.cpu().detach().half().numpy().astype(np.float16).data.tobytes())  # added .half() before .numpy()

delta = wt.half() - torch.tensor(onnx_data_mapping[initializer_name]).to(wt.device)  # added .half() to wt

The idea is to add .half() to cast the fp8 tensors to fp16 so they can be used in calculations with the other fp16 values; see the sketch below. Also note that "Cache FP16 weight for LoRA" in Settings/Optimizations doesn't work with this fix, so you need to apply a higher weight to the fp8 LoRA to achieve the same effect as the LoRA in fp16.
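For context, here is a minimal standalone sketch (not the extension's actual code) of the dtype problem those edits work around; it assumes a PyTorch build with torch.float8_e4m3fn support (2.1 or later):

```python
import numpy as np
import torch

# Illustration only: numpy has no fp8 dtype and fp8 tensors don't support
# arithmetic with fp16 tensors, so cast up to fp16 first.
wt = torch.randn(4, 4, dtype=torch.float16).to(torch.float8_e4m3fn)  # fp8 weight, as stored with the fp8 option
base = torch.randn(4, 4, dtype=torch.float16)                        # fp16 weight, e.g. read back from the ONNX data

wt_np = wt.cpu().detach().half().numpy().astype(np.float16)  # fp8 -> fp16 -> numpy; wt.numpy() alone would raise
delta = wt.half() - base                                      # fp8 - fp16 would fail; fp16 - fp16 is fine

print(wt_np.dtype, delta.dtype)  # float16 torch.float16
```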

By the way, if you check out the sd webui dev branch, which uses cu121, you can switch to TensorRT 9.0.1.post12.dev4 or the newer 9.2.0.post12.dev5 for CUDA 12 (building the 9.1.0.post12.dev4 wheel failed on my PC, so I don't recommend it). Make sure to modify install.py to update the version number (TensorRT still works even if you don't change it); a sketch of that kind of edit follows.
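For illustration, a hedged sketch of the version bump meant here, assuming install.py uses webui's launch helpers to pip-install a pinned tensorrt wheel from NVIDIA's package index (the actual file may look different):

```python
# Hypothetical sketch of the install.py version bump, not the extension's exact code.
import launch  # sd webui helper module available to extension install scripts

TRT_VERSION = "9.2.0.post12.dev5"  # assumed pin for the cu121 / CUDA 12 wheel

if not launch.is_installed("tensorrt"):
    launch.run_pip(
        f"install --pre --extra-index-url https://pypi.nvidia.com tensorrt=={TRT_VERSION}",
        f"tensorrt {TRT_VERSION}",
    )
```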

Vinzelles commented 3 weeks ago

Hello, following your method I successfully solved the LoRA export error, thank you very much. However, the converted .lora file is only 1 KB in size; is that normal? Also, how should the converted LoRA be used?