wangqiang9 / SketchKnitter

About PyTorch implementation of SketchKnitter: Vectorized Sketch Generation with Diffusion Models, ICLR 2023, Spotlight.
MIT License
58 stars 9 forks source link
diffusion diffusion-model diffusion-models pytorch sketch sketches sketching

SketchKnitter

Python 3.6 Pytorch 1.6 MIT License open issues

In this repository, you can find the PyTorch implementation of SketchKnitter: Vectorized Sketch Generation with Diffusion Models, ICLR 2023, Spotlight.

Authors: Qiang Wang, Haoge Deng, Yonggang Qi, Da Li, Yi-Zhe Song. Beijing University of Posts and Telecommunications, Samsung AI Centre Cambridge, University of Surrey.

Abstract: We show vectorized sketch generation can be identified as a reversal of the stroke deformation process. This relationship was established by means of a diffusion model that learns data distributions over the stroke-point locations and pen states of real human sketches. Given randomly scattered stroke-points, sketch generation becomes a process of deformation-based denoising, where the generator rectifies positions of stroke points at each timestep to converge at a recognizable sketch. A key innovation was to embed recognizability into the reverse time diffusion process. It was observed that the estimated noise during the reversal process is strongly correlated with sketch classification accuracy. An auxiliary recurrent neural network (RNN) was consequently used to quantify recognizability during data sampling. It follows that, based on the recognizability scores, a sampling shortcut function can also be devised that renders better quality sketches with fewer sampling steps. Finally it is shown that the model can be easily extended to a conditional generation framework, where given incomplete and unfaithful sketches, it yields one that is more visually appealing and with higher recognizability.

Fig.1

Datasets

Please go to the QuickDraw official website to download the datasets. The class list used in the paper: moon, airplane, fish, umbrella, train, spider, shoe, apple, lion, bus, you can also replace it with any other category. Each category class is stored in its own file, and contains training/validation/test set sizes of 70000/2500/2500 examples.

In addition to the QuickDraw dataset, you can train the model on any dataset, but please pay attention to organizing the dataset into vector format and packaging it into .npz file. In the case of less data sets, please pay attention to over-fitting. If you want to create your own dataset, you can follow the official tutorial of SketchRNN.

Installation

The requirements of this repo can be found in requirements.txt.

conda create -n sketchknitter python=3.7
conda activate sketchknitter
pip install -r requirements.txt

Train and Inference

Haperparameters

Here is a list of full options for the model:

lr,                   # learning rate.
log_dir,              # save log path.
dropout,              # dropout rate.
use_fp16,             # whether to use mixed precision training.
ema_rate,             # comma-separated list of EMA values
category,             # list of category name to be trained.
data_dir,             # the data sets path.
use_ddim,             # choose whether to use DDIM or DDPM
save_path,            # path to save vector results.
pen_break,            # determines the experience value of stroke break.
image_size,           # the max numbers of datasets.
model_path,           # path to save the trained model checkpoint.
class_cond,           # whether to use guidance technology.
batch_size,           # batch size of training.
emb_channels,         # Unet embedding channel numbers.
num_channels,         # the numbers of channels in Unet backbone.
out_channels,         # output channels in Unet. 
save_interval,        # saving models interval.
noise_schedule,       # the method of adding noise is linear by default.
num_res_blocks,       # numbers of resnet blocks in Unet backbone.
diffusion_steps,      # diffusion steps in the forward process.
schedule_sampler,     # the schedule of sampler.
fp16_scale_growth,    # the mixed precision scale growth.
use_scale_shift_norm, # whether to use scale shift norm. 

Train Example Usage:

bash train.sh

Inference Example Usage:

bash sample.sh

Visualization and Evaluation

sample.py obtained results are stored in .npz file format, which can be directly used to calculate quantitative indicators such as FID, IS, GS, etc. You can also visualize the results for qualitative experiments, refer to the results of Google Brain for the visualization method.

To calculate of FID and IS, refer to the official code. The calculation of Geometry Score(GS) can directly use data in vector format, please go to the official website for instructions.

Results

Simple FID↓ GS↓ Prec↑ Rec↑
SketchPix2seq 13.3 7.0 0.40 0.79
SketchHealer 10.3 5.9 0.45 0.81
SketchRNN 10.8 5.4 0.44 0.82
Diff-HW 13.3 6.8 0.42 0.81
SketchODE 11.5 9.4 0.48 0.74
Ours (full 1000 steps) 6.9 3.4 0.52 0.88
Ours (r-Shortcut, S=30) 7.4 3.9 0.47 0.87
Ours (Linear-DDIMs, S=30) 11.9 6.4 0.38 0.81
Ours (Quadratic-DDIMs, S=30) 12.3 6.6 0.41 0.79
Ours (Abs) 20.7 12.1 0.18 0.55
Ours (Point-Shuffle) 9.5 5.3 0.35 0.72
Ours (Stroke-Shuffle) 8.2 3.8 0.36 0.74

Fig 4

Only part of the results are listed here. For more detailed results, please see our paper and supplementary materials.

License

This project is released under the MIT License.

Citation

If you find this repository useful for your research, please use the following.

@inproceedings{wangsketchknitter,
  title={SketchKnitter: Vectorized Sketch Generation with Diffusion Models},
  author={Wang, Qiang and Deng, Haoge and Qi, Yonggang and Li, Da and Song, Yi-Zhe},
  booktitle={The Eleventh International Conference on Learning Representations}
}

Acknowledgements

Contact

If you have any questions about the code, please contact wanqqiang@bupt.edu.cn