kohya-ss / sd-scripts


Adding Validation Loss to detect Overtraining #193

Open · Naegles opened this issue 1 year ago

Naegles commented 1 year ago

I saw that the EveryDream2trainer repo has added the ability to use a validation loss to help figure out the optimal amount of training. It might be an interesting addition here.

https://github.com/victorchall/EveryDream2trainer/blob/main/doc/VALIDATION.md
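For a sense of what this involves mechanically, here is a rough sketch of a held-out validation pass for the usual noise-prediction objective. It assumes a diffusers-style UNet and noise scheduler and a dataloader that yields pre-encoded latents and text-encoder hidden states; none of the names come from sd-scripts or EveryDream2trainer.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def validation_loss(unet, noise_scheduler, val_dataloader, device):
    """Average the training objective over a held-out set without updating weights."""
    unet.eval()
    losses = []
    for batch in val_dataloader:
        latents = batch["latents"].to(device)
        encoder_hidden_states = batch["encoder_hidden_states"].to(device)

        # Same objective as a training step: predict the noise added to the latents.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=device,
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        losses.append(F.mse_loss(noise_pred.float(), noise.float()).item())
    unet.train()
    return sum(losses) / len(losses)
```

Tracking this number over the course of training is the overtraining signal: it should fall alongside the training loss early on and start climbing once the model begins to memorize the training set.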

shirayu commented 1 year ago

Good idea. I believe that it is not enough to simply observe loss/average, since the weights are continually being updated.

kohya-ss commented 1 year ago

Thank you for the suggestion! I was wondering whether validation loss is a valid metric, since the SD loss fluctuates so much, but the document in the EveryDream repo is very interesting.

I will consider implementing the validation loss, but it will take some time...
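One way to tame that fluctuation is to pre-sample a fixed noise tensor and timestep for every validation example with a seeded generator, so that the only thing changing between evaluations is the weights. A minimal sketch, assuming plain PyTorch (nothing here is from EveryDream2trainer or sd-scripts):

```python
import torch

def fixed_validation_noise(num_val_examples, latent_shape, num_train_timesteps, seed=0):
    """Pre-sample one (noise, timestep) pair per validation example, reused at every evaluation."""
    g = torch.Generator().manual_seed(seed)
    noise = torch.randn((num_val_examples, *latent_shape), generator=g)
    timesteps = torch.randint(0, num_train_timesteps, (num_val_examples,), generator=g)
    return noise, timesteps
```

The validation pass then indexes into these tensors instead of sampling fresh noise, which makes successive validation losses directly comparable.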

uwidev commented 1 year ago

I think this would be a good way to determine learning rates. There's a lot of speculation about the proper rate, but with validation we should be able to verify appropriate learning rates empirically. The question then is whether that value would be universally good when training something specific (e.g. a style), or good for everything. I think there's a lot of potential to be gained here.
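As a rough illustration of that idea, a sweep could score each candidate learning rate by its best validation loss. This is only a sketch; `build_model`, `train_one_epoch`, and `validation_loss` are placeholder helpers, not anything that exists in sd-scripts:

```python
def sweep_learning_rates(build_model, train_one_epoch, validation_loss,
                         candidate_lrs, epochs=10):
    """Train each candidate learning rate from scratch and keep its best validation loss."""
    results = {}
    for lr in candidate_lrs:
        model, optimizer = build_model(lr)   # fresh weights and optimizer per candidate
        best_val = float("inf")
        for _ in range(epochs):
            train_one_epoch(model, optimizer)
            best_val = min(best_val, validation_loss(model))  # best epoch before overfitting
        results[lr] = best_val
    best_lr = min(results, key=results.get)
    return best_lr, results
```

Whether the winning rate transfers from, say, a style dataset to a character dataset would still have to be tested per use case, as noted above.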

slashedstar commented 1 year ago

So... is this still being considered, or was it scrapped? @kohya-ss

kohya-ss commented 1 year ago

Sorry it has taken so long. This is on the task list, but I have not been able to get to it because other tasks have taken priority.

It also requires enhancing the dataset classes, so it will take some time...

AMorporkian commented 1 year ago

I have implemented this in a fork I originally created to implement the new noise scheduling functions. I won't create a PR because the massive refactoring I have done is largely experimental, but you can check it out here. I'm doing a run with wandb right now to test LoRA settings.

https://github.com/AMorporkian/kohya_ss/tree/hypertune

Sorry for the absolutely terrible commits; I was just messing around and the messages are not at all descriptive. If there isn't any movement on this by next week, I'll take some time to write proper code that would integrate better.
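For anyone who wants to wire something similar up themselves, a minimal sketch of logging a periodic validation loss to wandb alongside the training loss could look like the following. This is not taken from the fork above; `training_step`, `validation_loss_fn`, and the dataloader are assumed placeholders, while `wandb.init`, `wandb.log`, and `wandb.finish` are the real wandb calls:

```python
import wandb

def train_with_validation_logging(training_step, validation_loss_fn,
                                  train_dataloader, num_epochs, validate_every=100):
    """Log training loss every step and validation loss every `validate_every` steps."""
    wandb.init(project="lora-validation-test")   # project name is just an example
    global_step = 0
    for _ in range(num_epochs):
        for batch in train_dataloader:
            loss = training_step(batch)
            wandb.log({"train/loss": loss}, step=global_step)
            if global_step % validate_every == 0:
                wandb.log({"val/loss": validation_loss_fn()}, step=global_step)
            global_step += 1
    wandb.finish()
```

Plotting the two curves on the same dashboard makes the point where they diverge (training loss still falling, validation loss flattening or rising) easy to spot.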

kmacmcfarlane commented 1 year ago

This seems like a really cool idea. I have to do a lot of different training runs to figure out the optimal settings for a given training set. Even when working in epochs, the number of training images still seems to affect the results somewhat.

rockerBOO commented 11 months ago

> I have implemented this in a fork I originally created to implement the new noise scheduling functions. I won't create a PR because the massive refactoring I have done is largely experimental, but you can check it out here. I'm doing a run with wandb right now to test LoRA settings.
>
> https://github.com/AMorporkian/kohya_ss/tree/hypertune

I borrowed some of your ideas (the random_split and collate updates) and put together a copy/paste validation loss. That helped avoid having to create a custom dataset just to get the ball rolling.

If anyone else is interested in testing, see #914. Expanding to separate validation datasets would be the next step, but this works well with minimal changes.
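For reference, here is a minimal sketch of that kind of split using torch.utils.data.random_split; it is not the actual code in #914, and the dataset and collate function are whatever the trainer already uses:

```python
import torch
from torch.utils.data import DataLoader, random_split

def make_train_val_loaders(dataset, collate_fn, batch_size, val_fraction=0.1, seed=42):
    """Carve a fixed validation subset out of an existing dataset."""
    val_size = max(1, int(len(dataset) * val_fraction))
    train_size = len(dataset) - val_size
    generator = torch.Generator().manual_seed(seed)   # fixed split for comparable runs
    train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=generator)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_fn)
    return train_loader, val_loader
```

Seeding the generator keeps the split identical across runs, which matters if the goal is to compare hyperparameters rather than just watch a single run.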

jacquesfeng123 commented 4 months ago

Any updates on this? :D