SkyTNT / anime-segmentation

high-accuracy segmentation for anime character
Apache License 2.0

Training #2

Closed: Zarxrax closed this issue 1 year ago

Zarxrax commented 1 year ago

Hi, could you provide more detail about how you put together the dataset and trained the model? I am interested in trying to train it, but am confused about the structure of the dataset.

Zarxrax commented 1 year ago

Nevermind, I had somehow missed the info at the very bottom to download your actual dataset. It's starting to make sense now. I am putting together a new dataset to try to enhance the model even further. Would you be interested in working with me to train it?

Zarxrax commented 1 year ago

@SkyTNT would you be able to explain a bit more about how the dataset works? I have finally gotten around to downloading the full thing and exploring the code.

My understanding is that the dataset is generated on the fly by compositing the BG and FG images with some augmentations. Really quite clever! But, if that is the case, then what is the purpose of the imgs & masks folders? Those masks are of extremely poor quality, so I'm not sure what they are supposed to be for.
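
The on-the-fly compositing described above amounts to alpha-blending a foreground cut-out over a background, with the foreground's alpha channel doubling as the ground-truth mask. Below is a minimal NumPy sketch. It assumes the FG is an RGBA array already resized to the BG's size. The function name `composite` is hypothetical, and this is not the repo's DatasetGenerator code.

```python
import numpy as np


def composite(fg_rgba: np.ndarray, bg_rgb: np.ndarray):
    """Alpha-blend an RGBA foreground over an RGB background.

    Hypothetical helper, not the repo's implementation. Both inputs are
    uint8 arrays with the same height and width. Returns (image, mask):
    the blended uint8 RGB image and the uint8 mask taken from the
    foreground's alpha channel.
    """
    if fg_rgba.shape[:2] != bg_rgb.shape[:2]:
        raise ValueError("foreground and background must have the same size")
    alpha = fg_rgba[..., 3:4].astype(np.float32) / 255.0  # HxWx1 in [0, 1]
    fg = fg_rgba[..., :3].astype(np.float32)
    bg = bg_rgb.astype(np.float32)
    blended = alpha * fg + (1.0 - alpha) * bg
    image = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    mask = fg_rgba[..., 3].copy()  # the alpha channel is the GT mask
    return image, mask
```

Random augmentations (flips, scaling, colour jitter and so on) would be applied to the FG and BG before a call like this, so every epoch sees new combinations.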

SkyTNT commented 1 year ago

The imgs & masks folders contain manually labeled images. When I tried training with only the generated dataset, the model's accuracy dropped. I guess the generated data deviates from the real data, so I kept them.

In my training, the convert-to-sketch data augmentation seems to improve accuracy considerably, but my convert-to-sketch implementation still needs work to look closer to real sketches.

I also found that compositing and augmenting at low resolution can make image details look unnatural, but doing it at high resolution is slow. If your CPU is powerful enough, you can increase the resolution of DatasetGenerator and add RescalePad(image_size) to transform_generator: https://github.com/SkyTNT/anime-segmentation/blob/c98f338ce402a1e7bfccb93b66c9626c741b28c5/data_loader.py#L272-L273
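
`RescalePad(image_size)` is the repo's own transform. A rescale-and-pad step typically scales the longer side to the target size, keeps the aspect ratio, and pads the remainder. The hypothetical NumPy version below illustrates that. It uses nearest-neighbour sampling to stay dependency-free and centres the image. The real transform may interpolate and place the padding differently.

```python
import numpy as np


def rescale_pad(image: np.ndarray, size: int, fill: int = 0) -> np.ndarray:
    """Resize so the longer side equals `size`, then pad to size x size.

    Hypothetical stand-in for RescalePad: nearest-neighbour resize,
    centred, padded with `fill`. Works for HxW and HxWxC arrays.
    """
    h, w = image.shape[:2]
    scale = size / max(h, w)
    new_h = max(1, round(h * scale))
    new_w = max(1, round(w * scale))
    # Nearest-neighbour sampling: map each output pixel back to a source pixel.
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    resized = image[rows[:, None], cols[None, :]]
    out = np.full((size, size) + image.shape[2:], fill, dtype=image.dtype)
    top = (size - new_h) // 2
    left = (size - new_w) // 2
    out[top:top + new_h, left:left + new_w] = resized
    return out
```

For example, a 1080x1920 frame with `size=1024` becomes a 576x1024 image centred vertically with 224 rows of padding above and below.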

Zarxrax commented 1 year ago

Thanks, that is helpful to know. I will experiment with the augmentations, and I think I can put together an alternate dataset of the manually labeled images.

So far I have been able to run the training code, but I am running into a problem with export.py:

RuntimeError: Error(s) in loading state_dict for AnimeSegmentation

Do you know what might cause this?
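
An `Error(s) in loading state_dict` error generally means the checkpoint's parameter names do not match the ones the model expects. In PyTorch, `model.load_state_dict(state_dict, strict=False)` returns the missing and unexpected keys instead of raising, which is a quick way to see the mismatch. The plain-Python sketch below performs the same comparison on two key lists. The helper name and the key names in the example are made up for illustration.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names, each sorted.

    missing: expected by the model but absent from the checkpoint.
    unexpected: present in the checkpoint but unknown to the model.
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Hypothetical key names, for illustration only.
model_keys = ["net.stage1.conv.weight", "net.side1.weight"]
checkpoint_keys = ["gt_encoder.stage1.conv.weight", "net.side1.weight"]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```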

SkyTNT commented 1 year ago

Can you provide more detailed information?

Zarxrax commented 1 year ago

Oh, it was my mistake. I didn't realize training has two stages (first the GT encoder, then the main model), so I had not trained long enough to be able to export.
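
For context on the two stages: in the original DIS/IS-Net design, a GT encoder is first trained on the ground-truth masks and then frozen. Its intermediate features then supervise the segmentation network through a feature-level MSE term added to the usual mask loss. Below is a minimal NumPy sketch of that feature-synchronization term. It assumes the two networks expose matching lists of feature maps. The function name is hypothetical, and this is not the repo's `compute_loss`.

```python
import numpy as np


def feature_sync_loss(student_feats, gt_feats):
    """Sum of per-stage mean squared errors between feature maps.

    student_feats: feature maps from the segmentation network.
    gt_feats: feature maps from the frozen GT encoder, same shapes.
    """
    if len(student_feats) != len(gt_feats):
        raise ValueError("need one GT feature map per student feature map")
    total = 0.0
    for s, g in zip(student_feats, gt_feats):
        if s.shape != g.shape:
            raise ValueError(f"shape mismatch: {s.shape} vs {g.shape}")
        total += float(np.mean((s - g) ** 2))
    return total
```

Because the GT encoder is frozen in the second stage, only the segmentation network would receive gradients from a term like this.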

Zarxrax commented 1 year ago

After updating PyTorch to the CUDA 11.7 build, I now get an error when trying to train:

File "D:\Downloads\anime-segmentation\train.py", line 113, in training_step
    loss0, loss = self.net.compute_loss(loss_args)
TypeError: ISNetGTEncoder.compute_loss() missing 1 required positional argument: 'targets'

Is this a bug in the CUDA 11.7 build of PyTorch? Do I need to downgrade?

SkyTNT commented 1 year ago

Did you modify ISNetGTEncoder.compute_loss? https://github.com/SkyTNT/anime-segmentation/blob/2373527d745755fbe2987ba146d9326fed8e8881/model/isnet.py#L431-L434

Zarxrax commented 1 year ago

Thank you, it had changed somehow, so I was able to download the file again and it works. I am not sure what happened, but maybe I accidentally copied the file from the original DIS repo. I was looking at the code all day, so I probably made stupid mistakes.
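
The `TypeError` above is the kind of error Python raises when a caller and a method disagree on the signature. Here, `training_step` appears to pass all loss inputs packed into one argument. If `compute_loss` instead expects predictions and targets as two separate positional arguments, Python reports the second one as missing. The two classes below are hypothetical and only demonstrate the mechanism.

```python
class PackedArgsEncoder:
    """Expects every loss input packed into a single tuple."""

    def compute_loss(self, args):
        preds, targets = args
        return sum(abs(p - t) for p, t in zip(preds, targets))


class SeparateArgsEncoder:
    """Expects predictions and targets as two positional arguments."""

    def compute_loss(self, preds, targets):
        return sum(abs(p - t) for p, t in zip(preds, targets))


loss_args = ([0.2, 0.9], [0.0, 1.0])

print(PackedArgsEncoder().compute_loss(loss_args))  # signature matches the call

try:
    SeparateArgsEncoder().compute_loss(loss_args)  # same call, other signature
except TypeError as err:
    print(err)  # reports the missing 'targets' argument
```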

If I want to write the generated dataset images to files on my disk, can you recommend the best place to insert the code for that? I would like to see the images so I can tweak the generator more easily.

SkyTNT commented 1 year ago

You can insert the code at https://github.com/SkyTNT/anime-segmentation/blob/2373527d745755fbe2987ba146d9326fed8e8881/dataset_generator.py#L311-L312 with argparse.ArgumentParser
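
The helper below is a rough sketch of what could go at that point. It writes (image, mask) pairs to numbered PNG files so the generated samples can be inspected. It assumes samples arrive as HxWx3 images and HxW masks as NumPy arrays, either uint8 or float in [0, 1]. `dump_samples` and `write_png` are hypothetical helpers. The minimal PNG writer uses only `zlib` and `struct`. In practice an imaging library such as OpenCV (`cv2.imwrite`) or Pillow would do the same job.

```python
import os
import struct
import zlib

import numpy as np


def _to_uint8(arr: np.ndarray) -> np.ndarray:
    # Floats are assumed to be in [0, 1]; scale them to 0..255.
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.clip(arr * 255.0, 0, 255).round()
    return arr.astype(np.uint8)


def write_png(path: str, arr: np.ndarray) -> None:
    """Write an HxW (gray) or HxWx3 (RGB) array as an 8-bit PNG."""
    arr = _to_uint8(arr)
    if arr.ndim == 2:
        color_type = 0  # grayscale
    elif arr.ndim == 3 and arr.shape[2] == 3:
        color_type = 2  # RGB
    else:
        raise ValueError(f"unsupported shape {arr.shape}")
    h, w = arr.shape[:2]

    def chunk(tag: bytes, data: bytes) -> bytes:
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    # Each scanline is prefixed with filter type 0 (no filtering).
    raw = b"".join(b"\x00" + row.tobytes() for row in np.ascontiguousarray(arr))
    ihdr = struct.pack(">IIBBBBB", w, h, 8, color_type, 0, 0, 0)
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n")
        f.write(chunk(b"IHDR", ihdr))
        f.write(chunk(b"IDAT", zlib.compress(raw)))
        f.write(chunk(b"IEND", b""))


def dump_samples(samples, out_dir: str, limit: int = 50) -> int:
    """Save up to `limit` (image, mask) pairs as NNNNN_img.png / NNNNN_mask.png."""
    os.makedirs(out_dir, exist_ok=True)
    count = 0
    for i, (image, mask) in enumerate(samples):
        if i >= limit:
            break
        write_png(os.path.join(out_dir, f"{i:05d}_img.png"), image)
        write_png(os.path.join(out_dir, f"{i:05d}_mask.png"), mask)
        count += 1
    return count
```

If the generator yields CHW tensors instead, they would need to be transposed to HWC (and masks squeezed to HxW) before calling `dump_samples`.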

Zarxrax commented 1 year ago

A few more questions, if you don't mind.

  1. How many epochs did you have to train the model? Both the gt model and the standard model.
  2. Is it possible to train the gt model separately? I assume I could train using --net isnet_gt, and then next use --net isnet? How would I tell the isnet which gt model to use?
  3. My goal is to train a model that is optimized for real anime video frames at 1080p resolution. Do you think it's better to leave the model --img-size at 1024, or should I use 1080? (I assume I would need to change some of the code, such as in dataset_generator.py, to accommodate the changed resolution.)

Thanks!

SkyTNT commented 1 year ago

  1. The GT model needs only a few epochs of training, but the standard model needs hundreds or even thousands of epochs.
  2. Training the GT model is just the first training stage. It never changes after the standard model starts training, so you don't need to train them separately.
  3. 1080 should be OK. The augmentations in dataset_generator adapt automatically to the resolution, but I'm not sure they are correct. You can check the dataloader output with dataloader_test.py or test.py.
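
Some quick arithmetic helps explain why a size that is not a power of two, such as 1080, can still work. Assume each encoder stage halves the spatial size and rounds up, as PyTorch pooling does with `ceil_mode=True`. The number of halvings used below is illustrative and has not been checked against this repo's isnet.py.

```python
import math


def downsampled_sizes(size: int, num_halvings: int, ceil_mode: bool = True):
    """Spatial size after each 2x downsampling step.

    ceil_mode=True rounds up (like PyTorch pooling with ceil_mode=True);
    ceil_mode=False rounds down (the default floor behaviour).
    """
    sizes = [size]
    for _ in range(num_halvings):
        size = math.ceil(size / 2) if ceil_mode else size // 2
        sizes.append(size)
    return sizes


for s in (1024, 1080):
    print(s, downsampled_sizes(s, 6))
```

With 1080 the sizes go 1080, 540, 270, 135, 68, ..., so 135 does not halve evenly. The decoder therefore has to resize each upsampled map to the size of the encoder feature it is fused with, rather than simply doubling it. The original U²-Net/DIS code appears to do this with an `_upsample_like` helper.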