nopperl / marked-lineart-vectorization

Vectorization of clean line-art raster images using an encoder-decoder model.
https://huggingface.co/spaces/nopperl/lineart-vectorizer
Apache License 2.0

Concerns Regarding Training Duration of Vectorization Method #1

Open · IoanaVoica20 opened this issue 3 months ago

IoanaVoica20 commented 3 months ago

Hello! I have been experimenting with training a vectorization method using the provided parameters in the 'marked-clean.yaml' configuration file. However, I've noticed that the loss seems to remain constant.

Additionally, I'm unsure about how the loss is aggregated across the 8 GPUs available on my school server. It's unclear whether the loss displayed during training is the sum of losses from all GPUs or if it's a mean value.

Could you provide insights into the expected training duration for this vectorization method and how many epochs it takes?

Thanks in advance for your help!

nopperl commented 3 months ago

Hi!

Thanks for taking interest!

marked-clean.yaml is the correct starting point. It is definitely not the expected behaviour that the loss remains constant. There should be a drastic decrease already in the first iterations. Could you send me the training output?

> Additionally, I'm unsure about how the loss is aggregated across the 8 GPUs available on my school server. It's unclear whether the loss displayed during training is the sum of losses from all GPUs or if it's a mean value.

I only trained the models on a single GPU, so I don't know if the code even works with multiple GPUs. What command did you run for this? If you ran it with DDP, the loss will not be aggregated. Instead, every rank (GPU process) has its own local loss value, which is then concurrently written to the log, probably leading to a race condition.

But depending on what you're trying to achieve, getting distributed training to work correctly might be overkill. The model is small enough to fit on one GPU (with batch_size: 64 in marked-clean.yaml it takes up roughly 10 GB) and shouldn't take longer than a few hours to show good initial results.

> Could you provide insights into the expected training duration for this vectorization method and how many epochs it takes?

The provided model was trained for roughly 250 000 iterations (around 780 epochs), with each iteration taking 2.5 seconds on average. In total, this took about a week. However, you can already get good performance after training for less than 50 000 iterations (around 150 epochs), which would only take about 30 hours.
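As a quick sanity check, the figures above are consistent with each other. The sketch below only re-derives them from the numbers in this comment (2.5 s per iteration, 250 000 iterations ≈ 780 epochs, 50 000 iterations ≈ 150 epochs); nothing in it is measured.

```python
# Back-of-the-envelope check of the training-time figures quoted above.
# All inputs come from the comment; the outputs are approximations.
SECONDS_PER_ITERATION = 2.5  # reported average time per iteration


def training_stats(iterations: int, epochs: int) -> tuple[float, float]:
    """Return (iterations per epoch, wall-clock hours) for a run."""
    iters_per_epoch = iterations / epochs
    hours = iterations * SECONDS_PER_ITERATION / 3600
    return iters_per_epoch, hours


full_ipe, full_hours = training_stats(250_000, 780)
short_ipe, short_hours = training_stats(50_000, 150)

print(f"full run:  ~{full_ipe:.0f} it/epoch, ~{full_hours / 24:.1f} days")
print(f"short run: ~{short_ipe:.0f} it/epoch, ~{short_hours:.0f} hours")
```

Both runs work out to roughly 320 to 330 iterations per epoch. The full run comes to a bit over 7 days. At exactly 50 000 iterations the estimate is about 35 hours, so the "about 30 hours" figure corresponds to stopping somewhat before that point.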

For a reference on training details, you can also check out Section 4.1.1 (especially Figure 4.4) and Section 4.4 of my thesis.

IoanaVoica20 commented 3 months ago

Thank you for your response.

I use this command to run the training:

    python3 marked_lineart_vec/train.py -c configs/marked-clean.yaml

And I use DDP like this in the training script:

    runner = Trainer(strategy='ddp', ...)  # the rest of the args

This is the current progress of my training (epoch 130): [two screenshots]

Initially, I considered the number of epochs used for training, specifically 130 epochs, to be quite high. However, based on your response, I understand that I need to allow more time for the model to achieve better results.

nopperl commented 3 months ago

> I use this command to run the training: python3 marked_lineart_vec/train.py -c configs/marked-clean.yaml

OK, that's correct!

> And I use ddp like this in train script: runner = Trainer(strategy='ddp', ...#the rest of the args)

Good to know that you tested running it with DDP. Did you change anything else in the code to make it work? I was initially skeptical whether it would work without modifications, so it would be great to know it works out of the box.

> This is the current progress of my training (epoch 130)

The validation IoU looks good and is within the expected range for 14 000 iterations. The loss looks a bit different from my runs.

A potential cause for your results could be the following: I think that the logged IoU is an average over all ranks, because the code uses torchmetrics.IoU, which automatically syncs the metric (see https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#logging-torchmetrics). However, the loss is not a torchmetrics.Metric, so it is not automatically synced by default. For this, sync_dist=True might need to be set in the self.log(...) calls (see https://lightning.ai/docs/pytorch/stable/visualize/logging_advanced.html#sync-dist).
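To make the sum-vs-mean question concrete, here is a toy illustration with no GPUs and no Lightning involved; the per-rank loss values are made up. In the real module the change would be something like self.log("train_loss", loss, sync_dist=True) inside training_step. As I understand the Lightning logging docs linked above, the cross-rank reduction then follows reduce_fx, which defaults to the mean.

```python
from statistics import mean

# Hypothetical per-rank training losses for a single step on 8 GPUs.
# In real DDP training each value lives in a separate process.
local_losses = [0.91, 0.87, 0.95, 0.89, 0.93, 0.88, 0.90, 0.92]

# Without syncing, each rank only knows its own local value, so the
# number that reaches the log reflects a single rank (here: rank 0).
unsynced_rank0 = local_losses[0]

# With sync_dist=True and the default reduce_fx, the logged value is the
# mean over all ranks.
synced = mean(local_losses)

# With reduce_fx="sum", the sum over ranks would be logged instead.
summed = sum(local_losses)

print(f"rank 0 only:      {unsynced_rank0:.4f}")
print(f"sync_dist (mean): {synced:.4f}")
print(f"sync_dist (sum):  {summed:.4f}")
```

So without sync_dist, the logged loss is neither the sum nor the mean but one rank's local value, which can look noisier than a single-GPU run.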

I would recommend continuing the training and simply monitoring the validation IoU.

> Initially, I considered the number of epochs used for training, specifically 130 epochs, to be quite high. However, based on your response, I understand that I need to allow more time for the model to achieve better results.

You're right: at first glance, the number of epochs seems unusually high and would usually lead to overfitting. However, this is because the dataset is structured differently than usual. While the model in theory sees the same input images every epoch, it will (almost) always get different target curves (one of potentially thousands of curves per image). Together with data augmentation, this increases the effective dataset size by several orders of magnitude.
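The effect can be sketched in a few lines. This is not the repository's dataset code: CurveSamplingDataset, the curve counts and the augmentation names are hypothetical. The point is only that drawing a random target curve and a random augmentation on every access means repeated epochs over the same images rarely repeat a training sample.

```python
import random


class CurveSamplingDataset:
    """Hypothetical dataset: one random target curve per image access."""

    def __init__(self, curves_per_image, augmentations, seed=0):
        self.curves_per_image = curves_per_image  # number of curves in each image
        self.augmentations = augmentations        # hypothetical augmentation names
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.curves_per_image)

    def __getitem__(self, idx):
        # Same input image every epoch, but a (usually) different target.
        curve = self.rng.randrange(self.curves_per_image[idx])
        aug = self.rng.choice(self.augmentations)
        return idx, curve, aug


ds = CurveSamplingDataset([1200, 800, 3000, 450], ["none", "flip", "rotate90"])

seen = set()
for epoch in range(100):
    for i in range(len(ds)):
        seen.add(ds[i])

possible = sum(ds.curves_per_image) * len(ds.augmentations)
print(f"{len(ds)} images, {100 * len(ds)} samples drawn, "
      f"{len(seen)} distinct, {possible} possible (image, curve, aug) triples")
```

With only four images, 100 epochs still yield almost entirely distinct (image, curve, augmentation) triples, out of 16 350 possible ones.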

IoanaVoica20 commented 2 months ago

> Good to know that you tested running it with DDP. Did you change anything else in the code to make it work? I was initially skeptical whether it would work without modifications, so it would be great to know it works out of the box.

As far as I remember, I didn't change anything else, and it works fine.

Thanks a lot for your help and explanations. Also, congrats on all your hard work in this project and in your thesis as well. It's really impressive!

nopperl commented 2 months ago

Great!

Also thanks for your kind words! Just out of curiosity, how did you find this project and what are you trying to achieve?

IoanaVoica20 commented 2 months ago

To provide some context, I'm currently working on my final undergraduate thesis. My project involves developing an application that processes scanned technical drawings, such as sketches of house plans. The application's workflow includes cleaning up the image, applying a vectorization method, and then parsing the resulting SVG file to extract the elements and automatically create them in an AutoCAD project. Essentially, I'm building a converter that transforms raster images of technical sketches into AutoCAD projects, with the primary focus being on vectorization.
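For the SVG-parsing step of such a pipeline, a minimal starting point could look like the sketch below, using only Python's standard library. It assumes the vectorizer writes its curves as <path> elements in the standard SVG namespace; the example document is made up, and mapping the path data to AutoCAD entities is not shown.

```python
import xml.etree.ElementTree as ET

SVG_NS = "{http://www.w3.org/2000/svg}"


def extract_paths(svg_text: str) -> list[str]:
    """Return the `d` attribute of every <path> element, including nested ones."""
    root = ET.fromstring(svg_text)
    return [el.get("d", "") for el in root.iter(f"{SVG_NS}path")]


# Made-up example: one cubic Bezier curve and one straight line in a group.
example = """<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
  <path d="M 10 10 C 20 20, 40 20, 50 10" fill="none" stroke="black"/>
  <g><path d="M 10 80 L 90 80" stroke="black"/></g>
</svg>"""

for d in extract_paths(example):
    print(d)
```

If the SVG files have no xmlns declaration, the tag to search for would be plain "path" instead of the namespaced form.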

I stumbled upon your project while researching well-documented vectorization methods. It took quite a bit of searching before I found it, and I experimented with various approaches before settling on this one. Currently, I'm working on adapting the code to suit the specific needs of my project. Additionally, I'm interested in implementing a compression method like knowledge distillation to enhance the efficiency of the application.

Best regards

nopperl commented 2 months ago

Oh, that sounds interesting! I'm curious whether this method will turn out to be precise enough.

By far the biggest factor in model performance was the training data, so to improve results I recommend getting a good amount of quality data. If you haven't already checked, there are some technical sketch datasets I encountered during my research, but I'm not sure how exactly they fit your domain:

(Just as an aside, I also had to add amateur sketches to my training dataset, otherwise the model diverged. So it might be necessary to include a significant amount of simpler sketches as well...)

I'm really interested in how this method works in the end, so let me know about your research results if you choose to continue with this! Also, feel free to ask again if you have further questions or need advice.

IoanaVoica20 commented 2 months ago

Thank you for the recommendations. I've already used some of the datasets you provided when training the image cleaning network. The vectorization is working quite well, but I need to improve it for images with predominantly straight lines. For this purpose, I've created a synthetic dataset, similar to what I did for image cleaning, and I'm fine-tuning the trained model.
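One way to generate straight-line targets for such a synthetic dataset, assuming the training targets are cubic Bézier curves (SVG C commands), is to place the two inner control points at 1/3 and 2/3 of each segment; the resulting cubic traces the line exactly. The function names and the 256-pixel canvas below are hypothetical, not taken from the repository.

```python
import random

Point = tuple[float, float]


def line_as_cubic(p0: Point, p3: Point) -> list[Point]:
    """Represent the straight segment p0 -> p3 as a cubic Bezier curve.

    Inner control points at 1/3 and 2/3 of the segment give a cubic
    that traces the line exactly, at uniform speed.
    """
    (x0, y0), (x3, y3) = p0, p3
    p1 = (x0 + (x3 - x0) / 3, y0 + (y3 - y0) / 3)
    p2 = (x0 + 2 * (x3 - x0) / 3, y0 + 2 * (y3 - y0) / 3)
    return [p0, p1, p2, p3]


def random_straight_curves(n: int, size: float = 256.0, seed: int = 0) -> list[list[Point]]:
    """Sample n random straight segments inside a size x size canvas."""
    rng = random.Random(seed)
    curves = []
    for _ in range(n):
        p0 = (rng.uniform(0, size), rng.uniform(0, size))
        p3 = (rng.uniform(0, size), rng.uniform(0, size))
        curves.append(line_as_cubic(p0, p3))
    return curves


def to_svg_path(curve: list[Point]) -> str:
    """Format a cubic Bezier curve as SVG path data."""
    (x0, y0), (x1, y1), (x2, y2), (x3, y3) = curve
    return f"M {x0:.1f} {y0:.1f} C {x1:.1f} {y1:.1f}, {x2:.1f} {y2:.1f}, {x3:.1f} {y3:.1f}"


for curve in random_straight_curves(3):
    print(to_svg_path(curve))
```

Encoding lines this way keeps straight strokes in the same curve format as the rest of the data, so the fine-tuning targets need no separate line primitive.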

Here is an example image from the fine-tuning dataset.


I'm still working on processing the dataset to make it as similar as possible to the one I originally trained with. Currently, I'm encountering an error like: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn", but I'm striving to quickly figure out what's causing it and restart the training as soon as possible.

Thanks again for all the recommendations!