Nerogar / OneTrainer

OneTrainer is a one-stop solution for all your stable diffusion training needs.
GNU Affero General Public License v3.0

[Bug]: Setting "train" to false for TEs does not freeze TEs during SDXL finetuning #446

Open orcinus opened 2 months ago

orcinus commented 2 months ago

What happened?

Configuring TEs as follows:

"text_encoder": {
    "train": false, 
    "learning_rate": 2e-8, 
    "layer_skip": 0, 
    "weight_dtype": "FLOAT_32", 
    "stop_training_after": 5000, 
    "stop_training_after_unit": "STEP"
  },
  "text_encoder_2": {
    "train": false, 
    "learning_rate": 1e-8, 
    "layer_skip": 0, 
    "weight_dtype": "FLOAT_32", 
    "stop_training_after": 5000, 
    "stop_training_after_unit": "STEP"
  }, 

... the TEs do not get frozen during the finetune and are trained regardless. This is easily verifiable by diffing the TEs from the original model against the finetuned one, or by hooking up the original CLIP vs. the trained one in ComfyUI and comparing inferences (they will differ).
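
A minimal sketch of such a diff, assuming single-file safetensors checkpoints (the paths are placeholders, and the key prefix is an assumption about the usual SDXL single-file layout; adjust both for diffusers-style folders):

from safetensors.torch import load_file

def diff_text_encoders(original_path, finetuned_path, prefix="conditioner.embedders."):
    # Assumption: single-file SDXL checkpoints store both CLIP encoders under
    # "conditioner.embedders.0." and "conditioner.embedders.1.".
    original = load_file(original_path)
    finetuned = load_file(finetuned_path)
    for key in sorted(original):
        if key.startswith(prefix) and key in finetuned:
            # Compare in float32 so a pure dtype round-trip shows up as a
            # small uniform error instead of being masked.
            delta = (original[key].float() - finetuned[key].float()).abs().max().item()
            if delta > 0.0:
                print(f"{key}: max abs diff {delta:.3e}")

diff_text_encoders("sd_xl_base_1.0.safetensors", "finetuned.safetensors")

If the TEs were truly frozen and their dtype untouched, this should print nothing.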

Assuming this is my mistake and I've configured OneTrainer wrong - i.e. I also need to specify the "include": false parameter - why is that the case, and why are there two parameters for this?

What did you expect would happen?

TEs should be frozen and remain unchanged during finetune.

Relevant log output

not applicable

Output of pip freeze

absl-py==2.1.0
accelerate==0.30.1
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
async-timeout==4.0.3
attrs==24.2.0
bitsandbytes==0.43.1
certifi==2024.7.4
charset-normalizer==3.3.2
cloudpickle==3.0.0
coloredlogs==15.0.1
contourpy==1.2.1
customtkinter==5.2.2
cycler==0.12.1
dadaptation==3.2
darkdetect==0.8.0
-e git+https://github.com/huggingface/diffusers.git@dd4b731e68f88f58dfabfb68f28e00ede2bb90ae#egg=diffusers
filelock==3.15.4
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.6.1
ftfy==6.2.3
grpcio==1.65.4
huggingface-hub==0.23.3
humanfriendly==10.0
idna==3.7
importlib_metadata==8.2.0
invisible-watermark==0.2.0
Jinja2==3.1.4
kiwisolver==1.4.5
lightning-utilities==0.11.6
lion-pytorch==0.1.4
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
mdurl==0.1.2
-e git+https://github.com/Nerogar/mgds.git@d38efdf377a2d52c32aebf7820f10342e16221bf#egg=mgds
mpmath==1.3.0
multidict==6.0.5
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.20.5
nvidia-nvtx-cu11==11.8.86
omegaconf==2.3.0
onnxruntime-gpu==1.18.0
open-clip-torch==2.24.0
opencv-python==4.9.0.80
packaging==24.1
pillow==10.3.0
platformdirs==4.2.2
pooch==1.8.1
prodigyopt==1.0
protobuf==4.25.4
psutil==6.0.0
Pygments==2.18.0
pynvml==11.5.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytorch-lightning==2.2.5
pytorch_optimizer==3.0.2
PyWavelets==1.7.0
PyYAML==6.0.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
safetensors==0.4.3
scalene==1.5.41
schedulefree==1.2.5
scipy==1.13.1
sentencepiece==0.2.0
six==1.16.0
sympy==1.13.2
tensorboard==2.17.0
tensorboard-data-server==0.7.2
timm==1.0.8
tokenizers==0.19.1
torch==2.3.1+cu118
torchmetrics==1.4.1
torchvision==0.18.1+cu118
tqdm==4.66.4
transformers==4.42.3
triton==2.3.1
typing_extensions==4.12.2
urllib3==2.2.2
wcwidth==0.2.13
Werkzeug==3.0.3
xformers==0.0.27+cu118
yarl==1.9.4
zipp==3.20.0
orcinus commented 2 months ago

Will test with "include": false too in a few hours - need to finish the current finetune run first.

Nerogar commented 2 months ago

"Include" should not be changed for SDXL. It's only useful for models where one of the text encoders is optional, like SD3. I've never seen anything suggesting that disabled text encoder training doesn't work, and you can easily verify this in the code by checking whether the text encoder is even loaded into VRAM.
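
For illustration only (a generic PyTorch sketch, not OneTrainer's actual code), "train": false for a module typically amounts to freezing its parameters and keeping them out of the optimizer, which makes it impossible for the weights to drift:

import torch

def freeze(module: torch.nn.Module) -> None:
    # A frozen module receives no gradients and is kept in eval mode.
    module.requires_grad_(False)
    module.eval()

# Dummy stand-ins for the real UNet and text encoder.
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)

freeze(text_encoder)  # corresponds to "train": false

# Only parameters that still require grad reach the optimizer, so the
# frozen text encoder's weights cannot change during training.
trainable = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5)
print(sum(p.requires_grad for p in text_encoder.parameters()))  # prints 0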

hameerabbasi commented 2 months ago

Is it possible you changed the dtype of the text encoders, and therefore see slight differences in inference?

orcinus commented 2 months ago

@Nerogar yeah, I've checked the code after I posted this.

@hameerabbasi entirely possible; I just realized that this morning, but haven't had a chance to test it completely yet (another test run is in progress, will check after). Omitting the dtype from the TE config still produces the same result, but I still have output_dtype set to float32, so that's likely the cause.

For what it's worth, the differences in inference weren't slight; they were pretty significant (pose, style, background, everything) and got larger the farther into the training I tested. That's why it genuinely seemed like a training difference rather than just a small change caused by dtype.
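
To put a number on what a dtype round-trip alone does, a quick standalone check (values near unit scale, as is typical for normalized weights):

import torch

w = torch.randn(1_000_000, dtype=torch.float32)
roundtrip = w.to(torch.float16).to(torch.float32)
# float16 keeps ~11 bits of mantissa, so the maximum perturbation here is
# on the order of 1e-3: enough to nudge outputs slightly, but not enough
# to explain large pose/style/background shifts on its own.
print((w - roundtrip).abs().max().item())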

orcinus commented 2 months ago

Okay. Fine-tuning the original HF repo SDXL safetensors:

  • output_dtype set to FLOAT16
  • train_dtype set to FLOAT32
  • unet weight_dtype set to FLOAT32
  • TE1 and TE2 weight_dtype set explicitly to FLOAT16, train: false, include: false

The output .safetensors still inferences differently with its CLIP vs. the original SDXL CLIP.

What am I doing wrong here?
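
For reference, in the config-JSON shape quoted at the top of this issue, that setup would look roughly like the following (the FLOAT_16/FLOAT_32 spelling of the top-level dtype values is an assumption based on the weight_dtype values above):

"output_dtype": "FLOAT_16",
"train_dtype": "FLOAT_32",
"unet": { "train": true, "weight_dtype": "FLOAT_32" },
"text_encoder": { "train": false, "include": false, "weight_dtype": "FLOAT_16" },
"text_encoder_2": { "train": false, "include": false, "weight_dtype": "FLOAT_16" }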

O-J1 commented 3 weeks ago

Okay. Fine-tuning the original HF repo SDXL safetensors:

  • output_dtype set to FLOAT16
  • train_dtype set to FLOAT32
  • unet weight_dtype set to FLOAT32
  • TE1 and TE2 weight_dtype set explicitly to FLOAT16, train: false, include: false

The output .safetensors still inferences differently with its CLIP vs. the original SDXL CLIP.

What am I doing wrong here?

Can you provide some inference examples?

orcinus commented 3 weeks ago

Can you provide some inference examples?

In about a week or two. I'm away on a trip, and despite leaving a VPN tunnel open to my ML gear back home, the friggin' WireGuard server died after I left -_-

hameerabbasi commented 3 weeks ago
  • output_dtype set to FLOAT16

This actually might be the problem -- try FLOAT32.

orcinus commented 3 weeks ago

This actually might be the problem -- try FLOAT32.

But the original model is FLOAT16. The objective is to:

O-J1 commented 1 day ago

Can you provide some inference examples?

In about a week or two. I'm away on a trip, and despite leaving a VPN tunnel open to my ML gear back home, the friggin' WireGuard server died after I left -_-

@orcinus bump for followup

orcinus commented 1 day ago

Ahh, sorry - got back and went straight into complete chaos at work. Completely dropped the ball on this.

Will make something as soon as possible.