cloneofsimo / lora

Using Low-rank adaptation to quickly fine-tune diffusion models.
https://arxiv.org/abs/2106.09685
Apache License 2.0

qq about code re porting to flax #35

Closed krahnikblis closed 1 year ago

krahnikblis commented 1 year ago

dude this is awesome! I've been messing with textual inversion for a while but it's not as precise as i want, and this looks like the better way! ok, so i'd like to help extend this to the flax method, which runs way faster than torch on TPUs and even GPU, but since i'm not familiar with the DreamBooth and AUTOMATIC1111 codebases, can you point me to the parts of the training script that you modified? or i guess i can just try to diff the repos... which one did you fork from as the starting point? any gotchas to watch for, or has anyone already started on this front? also i noticed the pytorch checkpoints have different weight/layer names, hoping anyone reading can point to how we can map them across...
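
for context, here's roughly the kind of per-layer mapping i'm imagining; the key name below is just illustrative, not the actual diffusers layout:

```python
import numpy as np

def torch_linear_to_flax(pt_state, pt_key):
    """Convert one PyTorch nn.Linear entry into a Flax nn.Dense leaf dict.

    pt_key is e.g. "down_blocks.0.attentions.0.to_q" (illustrative only).
    PyTorch stores weight as (out_features, in_features); Flax Dense wants
    kernel as (in_features, out_features), hence the transpose.
    """
    leaf = {"kernel": np.asarray(pt_state[f"{pt_key}.weight"]).T}
    if f"{pt_key}.bias" in pt_state:
        leaf["bias"] = np.asarray(pt_state[f"{pt_key}.bias"])
    return leaf
```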

krahnikblis commented 1 year ago

ok, i'm a derp, i see the link to the original training script now. happy to discuss flaxy stuff here, but otherwise i have enough to work from...

yasyf commented 1 year ago

@krahnikblis not sure if you made progress on this, but I have an almost-done implementation at https://github.com/huggingface/diffusers/pull/1894 that could use some 👀

krahnikblis commented 1 year ago

@yasyf i don't know why i don't see alerts, sorry i missed this. yes, i have built a complete Flax workflow that makes a LoRA model which acts on a model's parameters (so, a complete departure from this repo's method of replacing layers in other models). it works very well for inference, and i can extract and use LoRAs from fine-tuned models (i built a workflow for that too, so i can grab big 10GB fine-tuned models from the HF site, extract LoRAs from them, save the 10MB LoRAs, and they work great). i also built a combiner module with many different methods, to test how best to combine multiple extracted fine-tune LoRAs back into one. basically, full success on what i was asking about in the "fine tune model extraction" thread. it's quite fun.
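
fwiw, per weight matrix the core of it boils down to something like the below (names are mine; the extraction is just a truncated SVD of the delta between the tuned and base kernels):

```python
import jax.numpy as jnp

def apply_lora(kernel, lora_A, lora_B, scale=1.0):
    # kernel: (in, out); lora_A: (in, r); lora_B: (r, out)
    # add the low-rank delta directly onto the base weight
    return kernel + scale * (lora_A @ lora_B)

def extract_lora(base_kernel, tuned_kernel, rank=4):
    # best rank-r approximation of the fine-tune delta, via truncated SVD
    delta = tuned_kernel - base_kernel
    u, s, vt = jnp.linalg.svd(delta, full_matrices=False)
    lora_A = u[:, :rank] * s[:rank]   # (in, r), singular values folded in
    lora_B = vt[:rank, :]             # (r, out)
    return lora_A, lora_B
```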

BUT, all my attempts at training are hitting a brick wall: i can't get the thing to compile. and it's not like college where you can ask the professor, "hey, how do i do this thing in jax?" (i've tried asking ChatGPT but it's not too helpful)
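
fwiw, the shape of the train step i'm trying to get jit to accept is roughly this; `apply_loras_to_params` and `diffusion_loss` here are placeholders for my own helpers, not real library calls:

```python
import jax
import optax

optimizer = optax.adamw(1e-4)

@jax.jit
def train_step(lora_params, opt_state, base_params, batch, rng):
    # only the LoRA pytree is differentiated; the frozen base params
    # are passed through untouched
    def loss_fn(lora_params):
        merged = apply_loras_to_params(base_params, lora_params)  # placeholder
        return diffusion_loss(merged, batch, rng)                 # placeholder
    loss, grads = jax.value_and_grad(loss_fn)(lora_params)
    updates, opt_state = optimizer.update(grads, opt_state, lora_params)
    lora_params = optax.apply_updates(lora_params, updates)
    return lora_params, opt_state, loss
```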