98mxr / GMFSS_union

High Performance GMFSS with RIFE and GAN
MIT License

Training code #1

Closed korakoe closed 1 year ago

korakoe commented 1 year ago

Hello, could you provide the training code, it would be really helpful for training custom models

98mxr commented 1 year ago

Now available here. There are bound to be mistakes in this code, so I won't close this issue; if you run into any problems, please ask, and don't forget to check my modified GAN Loss output.

hiredd commented 1 year ago

> Now available here. There are bound to be mistakes in this code, so I won't close this issue; if you run into any problems, please ask, and don't forget to check my modified GAN Loss output.

use RIFE_Fix_GAN_Loss_output.py or RIFE.py to train the model?

98mxr commented 1 year ago

> Now available here. There are bound to be mistakes in this code, so I won't close this issue; if you run into any problems, please ask, and don't forget to check my modified GAN Loss output.

> Use RIFE_Fix_GAN_Loss_output.py or RIFE.py to train the model?

That's it.

hiredd commented 1 year ago

So what's the difference between RIFE_Fix_GAN_Loss_output.py and RIFE.py?

hiredd commented 1 year ago

Hi, how do I train the model? I trained it using your code and the atd_12k dataset, but the result is really poor ~

98mxr commented 1 year ago

> Hi, how do I train the model? I trained it using your code and the atd_12k dataset, but the result is really poor ~

Fix_GAN_Loss_output only fixes the logged value of the GAN Loss, which is just a TensorBoard number; the GAN Loss itself works fine.

As for the "really poor" result, I'm afraid we'll have to take this step by step. First of all, the network relies on pre-trained weights: if the four pkl files in train_log are loaded properly, even the first training step should already produce good-quality output.

hiredd commented 1 year ago

Thanks for your reply. Where did you get the four pkl files, and which of them did you train yourself? I get a nan loss after about 20 epochs. Can you give some advice? Thanks!!

98mxr commented 1 year ago

train_log

I put a train_log folder with the pre-trained pkl files under the train folder. Loading them is mandatory in the related code, so it should report an error if any of them fails to load.
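Roughly, the mandatory loading amounts to something like this (a sketch only; the pkl file names below are placeholders, not necessarily the exact ones in the repo):

```python
import os
import torch

# Sketch: the four pre-trained pkl files in train_log must load successfully,
# and training should fail loudly if any of them is missing.
# The file names are placeholders, not necessarily the real ones.
TRAIN_LOG = "train/train_log"
for name in ("flownet.pkl", "metric.pkl", "feat.pkl", "fusionnet.pkl"):
    path = os.path.join(TRAIN_LOG, name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"pre-trained weight missing: {path}")
    state_dict = torch.load(path, map_location="cpu")  # then load into the matching sub-network
```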

There are multiple possible explanations for a nan loss; the most general fix for GMFSS is to change line 41 of metricnet.py from

return metric[:, :1], metric[:, 1:2]
to
return torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10)
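For context, a hypothetical sketch of where that change lives (the real metricnet.py has a different architecture; only the clamped return matters here):

```python
import torch
import torch.nn as nn

class MetricNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(10, 2, 3, padding=1)  # placeholder layer, not the real backbone

    def forward(self, x):
        metric = self.conv(x)
        # Clamp both metric channels so the softmax-splatting weights
        # cannot blow up past the FP32 range.
        return torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10)
```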

Thank you for your interest and please don't give up hope on GMFSS.

hiredd commented 1 year ago

Do you have any training logs?

Using this method I can train the model without a nan loss, but the result gets worse after a few epochs. Any tricks for training the model? Thanks!!

98mxr commented 1 year ago

> Do you have any training logs?

> Using this method I can train the model without a nan loss, but the result gets worse after a few epochs. Any tricks for training the model? Thanks!!

I haven't kept that log, but of course I don't mind running the training again if I have to.

I can't tell how far the result is "getting worse", but the discriminator does take a few epochs to warm up before it can do its job. By around the 10th epoch it should be working well enough for GMFSS to train properly.

I implemented a larger effective batch size in the code via gradient accumulation (I only have one V100 here), which is important: I used bs = 16*3. Gradient accumulation should work; use a larger batch size if you can.
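In case it helps, here is a bare-bones sketch of the gradient-accumulation idea (toy model and data, not the actual GMFSS training loop):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real model, optimizer, loss, and data.
model = nn.Conv2d(3, 3, 3, padding=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()
accum_steps = 3                        # 3 micro-batches of 16 -> effective bs 48

optimizer.zero_grad()
for step in range(30):
    img = torch.randn(16, 3, 64, 64)   # one micro-batch of 16
    gt = torch.randn(16, 3, 64, 64)
    loss = criterion(model(img), gt) / accum_steps  # scale so the summed gradient matches bs=48
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```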


hiredd commented 1 year ago

Thank you ~ Any other data augmentation or preprocessing besides what's in train.py? I still can't get a model as good as your pretrained one.

98mxr commented 1 year ago

> Thank you ~ Any other data augmentation or preprocessing besides what's in train.py? I still can't get a model as good as your pretrained one.

I decided to try running it again, so please wait for my results.

hiredd commented 1 year ago

Thanks a lot! Can you retrain the model using just the ATD-12K dataset (the same as my training config)? Then I can find the real reason my trained model is poor. Another question: why should I clip the metric as torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10), and why does this avoid the nan loss? Waiting for your results.

98mxr commented 1 year ago

> Thanks a lot! Can you retrain the model using just the ATD-12K dataset (the same as my training config)? Then I can find the real reason my trained model is poor. Another question: why should I clip the metric as torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10), and why does this avoid the nan loss? Waiting for your results.

The problem comes from the metric parameter Z required by the softmax-splatting warp. I use a small network, called MetricNet, to generate Z. I modified the official softmax-splatting implementation so that MetricNet takes the optical flow as input, but I did not normalize it. This means the flow (which can reach values of 300+) is fed directly into the network, which can make it output a metric parameter Z with a very large value. If Z is too large, the warp result easily goes out of the FP32 range and the network crashes.
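To make the numbers concrete (assuming the standard exp(Z) weighting used by softmax splatting):

```python
import torch

# FP32 overflows to inf once the exponent passes roughly 88, so an
# unnormalized metric in the hundreds breaks the warp.
print(torch.exp(torch.tensor(80.0)))    # ~5.5e34, still finite in FP32
print(torch.exp(torch.tensor(300.0)))   # inf, which turns into nan downstream
```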

I have been experimenting with the next generation of GMFSS after GMFSS_union, and the MetricNet bug was the first issue to be addressed, but I still don't have a good way to maintain performance while keeping the network stable under any conditions.

I'll be running the ATD-12K training experiments today and tomorrow, and I'll post any results here as soon as I have them.

98mxr commented 1 year ago

@hiredd The training is less than 20 epochs away and so far everything is fine on my end.

hiredd commented 1 year ago

Can you upload the model and logs you trained with atd-12k? I trained the model with distributed training on 4 NVIDIA V100s, following RIFE. Here are my logs: [screenshot: 企业微信截图_1676001602852]

98mxr commented 1 year ago

> Can you upload the model and logs you trained with atd-12k? I trained the model with distributed training on 4 NVIDIA V100s, following RIFE. Here are my logs: [screenshot: 企业微信截图_1676001602852]

Of course, I was going to wait until all the training was done before posting it, but it looks like you have time now.

Here is the training log up to epoch 55, the model and the list of datasets I used this time. https://drive.google.com/file/d/1y_Bo35ZAiD76Yy5BhT5Dxwum0CP3oeCj/view?usp=share_link

But my log screenshot looks similar to yours. Could it be that GMFSS doesn't rely on lower LPIPS and other metrics to get better visuals, and that is making you misjudge the training process?

[training log screenshot]

Justin62628 commented 1 year ago

Between fellow Chinese, there's no need to chat in a foreign language ✋

hiredd commented 1 year ago

Thanks ~ Your LPIPS loss looks smaller than mine. Did you clip the MetricNet output with torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10), and is your code the same as what you uploaded to GitHub?

98mxr commented 1 year ago

> Thanks ~ Your LPIPS loss looks smaller than mine. Did you clip the MetricNet output with torch.clamp(metric[:, :1], -10, 10), torch.clamp(metric[:, 1:2], -10, 10), and is your code the same as what you uploaded to GitHub?

I have checked my code and it is exactly the same as on GitHub. That also means I did not clip, and that I used train.py (not RIFE_Fix_GAN_Loss_output.py). Clipping affects the final result; it is just a stopgap measure for training stability. Also, the GAN Loss curve in the log will not match the Fix_output one.

98mxr commented 1 year ago

@hiredd I have re-uploaded my final 70-epoch model and log to Google Drive, and the code fully follows the version I uploaded to GitHub (no clip). I hope this helps you.

98mxr commented 1 year ago

Now that there is a log to help you judge the training process, I will close this issue.

hiredd commented 1 year ago

Thank you!!