dasayan05 / chirodiff

ChiroDiff: Modelling chirographic data with Diffusion Models
https://ayandas.me/chirodiff
MIT License


Accepted at the International Conference on Learning Representations (ICLR) 2023

Authors: Ayan Das, Yongxin Yang, Timothy Hospedales, Tao Xiang, Yi-Zhe Song


[OpenReview], [arXiv] & [Project Page]

Abstract: Generative modelling over continuous-time geometric constructs, a.k.a. chirographic data such as handwriting, sketches, drawings etc., has been accomplished through autoregressive distributions. Such strictly-ordered discrete factorization, however, falls short of capturing key properties of chirographic data: it fails to build a holistic understanding of the temporal concept due to one-way visibility (causality). Consequently, temporal data has been modelled as discrete token sequences at a fixed sampling rate instead of capturing the true underlying concept. In this paper, we introduce a powerful model class, namely "Denoising Diffusion Probabilistic Models" (DDPMs), for chirographic data that specifically addresses these flaws. Our model, named "ChiroDiff", is non-autoregressive; it learns to capture holistic concepts and therefore remains, to a good extent, resilient to higher temporal sampling rates. Moreover, we show that many important downstream utilities (e.g. conditional sampling, creative mixing) can be flexibly implemented using ChiroDiff. We further show that some unique use-cases, like stochastic vectorization, de-noising/healing and abstraction, are also possible with this model class. We perform quantitative and qualitative evaluation of our framework on relevant datasets and find it to be better than or on par with competing approaches.


Running the code

The instructions below explain how to run the code in this repository.

Table of contents:

  1. Environment & Libraries
  2. Data preparation
  3. Training & Sampling (including inference)

Environment & Libraries

Running the code may require some slightly outdated library versions. The full list is provided in requirements.txt in this repo. Please create a virtual environment with conda or venv and run

(myenv) $ pip install -r requirements.txt

Data preparation

You can feed the data in one of two ways: "unpacked" or "unpacked and preprocessed". The first dynamically loads data from individual files, whereas the latter packs the preprocessed input into one single .npz file, which increases training speed.
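The idea behind packing a dataset into one .npz file can be sketched as follows. This is a minimal sketch, not the repository's loader: it assumes each sample is a variable-length float array of shape (N, 3), and the file layout (the `points`/`offsets` keys) and the helper names `pack_samples`/`load_sample` are hypothetical. Variable-length samples are concatenated into one array, and an offsets array records where each one starts, so the whole dataset can be read once instead of opening one file per sample.

```python
import io
import numpy as np

def pack_samples(samples, file):
    """Pack variable-length (N_i, 3) float arrays into one .npz (hypothetical layout)."""
    lengths = np.array([len(s) for s in samples], dtype=np.int64)
    # offsets[i] is where sample i starts; offsets[-1] is the total number of points
    offsets = np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(lengths)])
    points = np.concatenate(samples, axis=0).astype(np.float32)
    np.savez(file, points=points, offsets=offsets)

def load_sample(packed, i):
    """Slice sample i back out of the packed arrays."""
    start, end = packed["offsets"][i], packed["offsets"][i + 1]
    return packed["points"][start:end]

# Usage: pack two toy samples into an in-memory buffer and read one back.
buf = io.BytesIO()
pack_samples([np.zeros((4, 3)), np.ones((2, 3))], buf)
buf.seek(0)
packed = dict(np.load(buf))  # load every array into memory once
print(load_sample(packed, 1).shape)  # (2, 3)
```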

Training & Sampling

There are multiple training "modes" corresponding to the model type (unconditional, sequence-conditioned, etc.):

threeseqdel # unconditional model with delta (velocity) sequence
threeseqdel_pointcloudcond # conditioned on pointcloud representation
threeseqdel_classcond # conditioned on class
threeseqdel_threeseqdelcond # conditioned on self

threeseqabs # unconditional model with absolute (position) sequence
threeseqabs_pointcloudcond # conditioned on pointcloud representation
threeseqabs_classcond # conditioned on class
threeseqabs_threeseqabscond # conditioned on self
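The "del" and "abs" families differ in how point coordinates are stored. A minimal sketch of converting between the two, assuming (our reading of the `threeseq` naming, not confirmed by this README) that each point is a triplet (x, y, pen_state) and that only the coordinates are differenced while the pen state is kept as-is; the function names are hypothetical:

```python
from itertools import accumulate

def abs_to_delta(points):
    """Absolute (x, y, pen) points -> delta (dx, dy, pen) points.

    The first point is expressed as an offset from the origin (0, 0).
    """
    deltas = []
    prev_x, prev_y = 0.0, 0.0
    for x, y, pen in points:
        deltas.append((x - prev_x, y - prev_y, pen))
        prev_x, prev_y = x, y
    return deltas

def delta_to_abs(deltas):
    """Delta points -> absolute points via a running (cumulative) sum."""
    xs = accumulate(d[0] for d in deltas)
    ys = accumulate(d[1] for d in deltas)
    return [(x, y, d[2]) for x, y, d in zip(xs, ys, deltas)]

stroke = [(0.0, 0.0, 0), (1.0, 2.0, 0), (3.0, 2.0, 1)]
print(abs_to_delta(stroke))  # [(0.0, 0.0, 0), (1.0, 2.0, 0), (2.0, 0.0, 1)]
```

One consequence of this choice: noise added to a delta sequence accumulates along the stroke once it is integrated back to positions, whereas noise on an absolute sequence stays local to each point.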

Note: For simplicity, we provide a config.yml file in which every command-line option can be set. Then run the main script as follows (options can also be passed directly on the command line):

(myenv) $ python main.py fit --config config.yml --model.arch_layer 3 --model.noise_T 100 ...
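`--model.noise_T` presumably sets the number of diffusion steps T (our reading of the name; the README does not say). For intuition, here is a minimal sketch of the standard DDPM forward process q(x_t | x_0) = N(sqrt(ᾱ_t) x_0, (1 - ᾱ_t) I), with ᾱ_t the cumulative product of (1 - β_s). The linear β schedule rescaled by 1000/T (so that ᾱ_T ends up near zero even for small T) is an assumed choice, and the helper names are hypothetical; ChiroDiff's actual schedule may differ.

```python
import math
import random

def linear_betas(T):
    """Linearly spaced beta_1..beta_T, rescaled by 1000/T (assumed schedule, T >= 2)."""
    if T < 2:
        raise ValueError("T must be at least 2")
    scale = 1000.0 / T
    start, end = 1e-4 * scale, 0.02 * scale
    step = (end - start) / (T - 1)
    # clamp so that 1 - beta stays positive for very small T
    return [min(start + i * step, 0.999) for i in range(T)]

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s <= t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def q_sample(x0, t, abar, rng=random):
    """Draw x_t ~ q(x_t | x_0) for a flat list of coordinates x0 (t is 0-indexed)."""
    a = abar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for x in x0]

T = 100  # e.g. the value passed via --model.noise_T 100
abar = alpha_bars(linear_betas(T))
x0 = [0.5, -0.2, 1.0]  # toy flattened coordinates
x_T = q_sample(x0, T - 1, abar, random.Random(0))  # close to pure Gaussian noise
```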

You will also need wandb for logging. Please use your own account and fill in the correct values of --trainer.logger.init_args.{entity, project} in the config.yml file. Alternatively, you may remove the wandb logger entirely and replace it with another logger of your choice; in that case, you might have to modify a few lines of code.

While training, the script saves the full configuration of the run, a "best model" and a "last model". Once training is done, pass the saved model (saved every 300 epochs) and the full configuration to the test command via the --ckpt_path and --config arguments, like so

(myenv) $ python main.py test --config ./logs/test-run/config.yaml --ckpt_path ./logs/test-run/.../checkpoints/model.ckpt --limit_test_batches 1
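Since the exact sub-directory under ./logs/test-run/ depends on the run, a small hypothetical helper (not part of the repository) can list the saved checkpoints before you pick one for --ckpt_path. It assumes only that checkpoints end in .ckpt and sit in a checkpoints/ folder somewhere under the run directory.

```python
from pathlib import Path

def find_checkpoints(run_dir):
    """Return all *.ckpt files inside any 'checkpoints' folder under run_dir, newest first."""
    ckpts = [p for p in Path(run_dir).rglob("*.ckpt") if p.parent.name == "checkpoints"]
    # sort by modification time so the most recently written checkpoint comes first
    return sorted(ckpts, key=lambda p: p.stat().st_mtime, reverse=True)
```

For example, `find_checkpoints("./logs/test-run")` returns a list of paths; any of them can be passed to --ckpt_path together with the run's config.yaml.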

By default, the test phase writes some visualizations that are helpful for inspection, for example generation results and a visualization of the diffusion process. Test-time options have a --test_ prefix. Feel free to play around with them.

(myenv) $ python main.py test --config ... --ckpt_path ... \
            --test_sampling_algo ddpm \
            --test_variance_strength 0.75 \
            --test_viz_process backward \
            --test_save_everything 1
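For intuition about `--test_sampling_algo ddpm` and `--test_variance_strength`, here is a minimal sketch of one ancestral DDPM reverse step, x_{t-1} = (x_t - β_t / sqrt(1 - ᾱ_t) · ε̂) / sqrt(α_t) + s·σ_t·z with z ~ N(0, I). It assumes (our interpretation of the flag, not confirmed by this README) that the variance strength s multiplies the noise standard deviation, using the simple choice σ_t = sqrt(β_t), so s = 0 makes the step deterministic given ε̂. `eps_hat` stands in for the trained network's noise prediction, and all names are hypothetical.

```python
import math
import random

def ddpm_reverse_step(x_t, eps_hat, t, betas, abar, strength=1.0, rng=random):
    """One ancestral DDPM step x_t -> x_{t-1}; t is 0-indexed, so t == 0 is the final step.

    `strength` scales the noise standard deviation (assumed meaning of
    --test_variance_strength); `eps_hat` is the network's noise prediction.
    """
    beta = betas[t]
    alpha = 1.0 - beta
    coef = beta / math.sqrt(1.0 - abar[t])
    # Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - abar_t) * eps_hat) / sqrt(alpha_t)
    mean = [(x - coef * e) / math.sqrt(alpha) for x, e in zip(x_t, eps_hat)]
    if t == 0:
        return mean  # no noise is injected at the last step
    sigma = strength * math.sqrt(beta)  # sigma_t = sqrt(beta_t), scaled by strength
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]

# Usage with a toy 2-step schedule and a zero noise prediction.
betas = [0.1, 0.2]
abar = [0.9, 0.9 * 0.8]
x_prev = ddpm_reverse_step([1.0, -1.0], [0.0, 0.0], 1, betas, abar,
                           strength=0.75, rng=random.Random(0))
```

Lowering the strength below 1 trades sample diversity for smoother, more deterministic trajectories.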

You can cite the paper as

@inproceedings{das2023chirodiff,
    title={ChiroDiff: Modelling chirographic data with Diffusion Models},
    author={Ayan Das and Yongxin Yang and Timothy Hospedales and Tao Xiang and Yi-Zhe Song},
    booktitle={The Eleventh International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=1ROAstc9jv}
}

Notes:

  1. This repository is part of our research codebase and may therefore contain code/options that are not part of the paper.
  2. This repo may also contain some implementation details that have been upgraded since the submission of the paper.
  3. The README is still incomplete and I will add more info when I get time. You may try different settings yourself.
  4. The default parameters might not match the ones in the paper. Feel free to play with them.