alexandre01 / deepsvg

[NeurIPS 2020] Official code for the paper "DeepSVG: A Hierarchical Generative Network for Vector Graphics Animation". Includes a PyTorch library for deep learning with SVG data.
https://www.reshot.ai
MIT License

Training classifier on bottleneck #15

Open tsaxena opened 3 years ago

tsaxena commented 3 years ago

@alexandre01 I am trying to train a classifier on the icons using the bottleneck embeddings that I get by running inference with the pretrained model. In some cases, though, it doesn't seem to work. I used the code in your latent_ops notebook to encode each icon, and for some icons I get this error: `The size of tensor a (10) must match the size of tensor b (8) at non-singleton dimension 0`. Am I missing something?

tsaxena commented 3 years ago

I am using the following functions to encode an icon, but they do not seem to work for some of the icon indices.

```python
# model, cfg, dataset, device and batchify come from the latent_ops notebook setup.
def encode(data):
    model_args = batchify((data[key] for key in cfg.model_args), device)
    with torch.no_grad():
        z = model(*model_args, encode_mode=True)
    return z


def encode_icon(idx):
    data = dataset.get(id=idx, random_aug=False)
    return encode(data)
```

alexandre01 commented 3 years ago

Hello @tsaxena! Yes, unfortunately this is because the pretrained model was trained with a maximum of 8 paths per SVG. Since the model uses an index embedding, it cannot run inference on SVGs with more paths than that.

You'd need to either train a model with more paths or filter your classification dataset down to icons with at most eight paths.
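As a rough illustration of the filtering option, here is a minimal sketch that counts `<path>` elements in raw SVG text and keeps only the icons within the limit. The helper names (`count_paths`, `filter_by_path_count`) are hypothetical and not part of deepsvg. The sketch also assumes that the raw `<path>` count is a reasonable proxy for what the model sees; the number of paths after deepsvg's own preprocessing may differ.

```python
import xml.etree.ElementTree as ET

MAX_PATHS = 8  # limit of the pretrained model, as stated in the reply above


def count_paths(svg_text: str) -> int:
    """Count <path> elements in an SVG document, ignoring XML namespaces."""
    root = ET.fromstring(svg_text)
    # Namespaced tags look like "{http://www.w3.org/2000/svg}path".
    return sum(1 for el in root.iter() if el.tag.rsplit("}", 1)[-1] == "path")


def filter_by_path_count(svgs: dict, max_paths: int = MAX_PATHS) -> dict:
    """Keep only the icons whose raw path count fits within max_paths."""
    return {
        name: text
        for name, text in svgs.items()
        if count_paths(text) <= max_paths
    }


def make_svg(num_paths: int) -> str:
    """Build a tiny SVG with the given number of paths (for demonstration only)."""
    paths = '<path d="M0 0 L1 1"/>' * num_paths
    return f'<svg xmlns="http://www.w3.org/2000/svg">{paths}</svg>'


icons = {"small": make_svg(3), "large": make_svg(10)}
usable = filter_by_path_count(icons)
print(sorted(usable))  # only the icon with 3 paths remains
```

Icons that pass this check could then be passed to `encode_icon`; the others would need a model trained with a larger path limit.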

tsaxena commented 3 years ago

Thanks for the prompt reply, @alexandre01. So does that mean you did not use all 100k icons to train the pretrained model?

alexandre01 commented 3 years ago

Yes, the dataset is about 100k icons, but because of time constraints the pretrained model was only trained on a filtered subset. And I don't have access to the GPU server I used anymore.

tsaxena commented 3 years ago

I eventually want to fine-tune the network on SVGs that have more than 8 paths. Do you suggest training from scratch? From what I understand, the maximum number of SVG paths is a configuration parameter that can be changed.

pwichmann commented 3 years ago

I am not @alexandre01. But what you say is correct, I think.

The max number of paths is a config parameter and can be changed in deepsvg/model/config.py:

self.max_num_groups = 8          # Number of paths (N_P)

There are also individual configs in configs/deepsvg/ that override this variable. You may need to increase the number of paths there as well.
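To make the override relationship concrete, here is a minimal sketch of the pattern. `BaseConfig` and `ExperimentConfig` are hypothetical stand-ins, not the actual deepsvg classes, and the value 16 is an arbitrary example. Only the attribute name `max_num_groups` and its default of 8 come from the line quoted above.

```python
class BaseConfig:
    """Hypothetical stand-in for the library-level defaults."""

    def __init__(self):
        self.max_num_groups = 8  # Number of paths (N_P), default quoted above


class ExperimentConfig(BaseConfig):
    """Hypothetical stand-in for a per-experiment config that overrides the default."""

    def __init__(self, max_num_groups=16):
        super().__init__()
        # The assignment here wins over the base default, so raising the
        # limit in the base class alone would have no effect.
        self.max_num_groups = max_num_groups


def fits(num_paths, cfg):
    """Return True if an SVG with num_paths paths is within the config's limit."""
    return num_paths <= cfg.max_num_groups


print(fits(10, BaseConfig()))        # False: 10 paths exceed the default of 8
print(fits(10, ExperimentConfig()))  # True: the override raises the limit to 16
```

The point of the sketch is that the value set last takes effect, which is why both places may need to be updated before retraining.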

You would need to retrain to cope with a larger number of paths. If you have a beefy GPU, retraining does not take very long. On my RTX 3090, I can retrain from scratch within hours.

What is your use case?