proceduralia / pytorch-GAN-timeseries

GANs for time series generation in pytorch

Financial time series generation using GANs

This repository contains an implementation of a GAN-based method for generating real-valued financial time series. For a related approach, see for instance "Real-valued (Medical) Time Series Generation with Recurrent Conditional GANs".

Main features:

During conditional training, daily deltas that are given as additional input to the generator are sampled from a Gaussian distribution estimated from real data via maximum likelihood.
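As a rough illustration of this sampling step, the sketch below fits a Gaussian to a handful of deltas by maximum likelihood and draws new deltas from it. It is not the repository's code: the function names `fit_gaussian_mle` and `sample_deltas` and the example values are hypothetical, and plain Python is used only to keep the sketch dependency-free.

```python
import random
import statistics


def fit_gaussian_mle(deltas):
    """Return the maximum-likelihood mean and standard deviation of a sample."""
    mu = statistics.fmean(deltas)
    # The MLE of the variance divides by n (not n - 1), i.e. the population form.
    sigma = statistics.pstdev(deltas, mu=mu)
    return mu, sigma


def sample_deltas(mu, sigma, n, seed=None):
    """Draw n deltas from N(mu, sigma^2); the seed makes the draw repeatable."""
    rng = random.Random(seed)
    return [rng.gauss(mu, sigma) for _ in range(n)]


# Made-up deltas standing in for values computed from real data.
real_deltas = [0.12, -0.05, 0.30, -0.22, 0.08, 0.01]
mu, sigma = fit_gaussian_mle(real_deltas)

# One sampled delta per generated sequence would be passed to the generator.
fake_deltas = sample_deltas(mu, sigma, n=4, seed=0)
```

For a Gaussian, the maximum-likelihood estimates are simply the sample mean and the population (divide-by-n) standard deviation, which is why no iterative fitting is needed.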

Some words on the dataset

In the original data, provided in CSV format, the time series values are taken from the btp_price feature. Minimal preprocessing, including normalization to the range [-1, 1], is done in btp_dataset.py. The resulting dataset has 173 sequences of length 96, for an overall tensor shape of (173 x 96 x 1). If your dataset is not compatible with this preprocessing, you can write your own loader.
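If you do need your own loader, a minimal sketch could look like the following. It is not the code in btp_dataset.py: the function name `load_normalized_sequences` is hypothetical, min-max scaling is one way to reach the range [-1, 1], and cutting the series into non-overlapping windows is an assumption, since how the sequences are formed is not stated here.

```python
import csv
import io


def load_normalized_sequences(csv_file, column="btp_price", seq_len=96):
    """Read one CSV column, scale it to [-1, 1], and cut it into sequences.

    Returns nested lists of shape (n_sequences, seq_len, 1).
    """
    values = [float(row[column]) for row in csv.DictReader(csv_file)]
    lo, hi = min(values), max(values)
    if hi == lo:
        raise ValueError("a constant series cannot be min-max scaled")
    # Min-max scaling: lo maps to -1 and hi maps to +1.
    scaled = [2.0 * (v - lo) / (hi - lo) - 1.0 for v in values]
    # Assumption: non-overlapping windows; leftover values at the end are dropped.
    n_seq = len(scaled) // seq_len
    return [
        [[x] for x in scaled[i * seq_len:(i + 1) * seq_len]]
        for i in range(n_seq)
    ]


# Tiny in-memory CSV standing in for a real file (8 rows, windows of length 4).
demo_csv = io.StringIO("btp_price\n" + "\n".join(str(100 + i) for i in range(8)))
sequences = load_normalized_sequences(demo_csv, seq_len=4)
```

The nested lists can then be converted with `torch.tensor(sequences)` to obtain a tensor of shape (n_sequences, seq_len, 1).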

Project structure

The files and directories composing the project are:

By default, during training, model weights are saved into the checkpoints/ directory, snapshots of generated series into images/ and tensorboard logs into log/.

Use:

tensorboard --logdir log

from inside the project directory to run TensorBoard on the default port (6006).

Examples

Run training with a recurrent generator and a convolutional discriminator, conditioning the generator on deltas and alternating adversarial and supervised optimization:

python main.py --dataset_path some_dataset.csv --delta_condition --gen_type lstm --dis_type cnn --alternate --run_tag cnn_dis_lstm_gen_alternate_my_first_trial

Generate a fake dataset, prova.npy, using the deltas in delta_trial.txt and a model trained for 70 epochs:

python generate_dataset.py --delta_path delta_trial.txt --checkpoint_path checkpoints/cnn_conditioned_alternate1_netG_epoch_70.pth --output_path prova.npy
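To avoid typing the epoch number by hand, a small helper can pick the most recent generator checkpoint for a run. This is only a sketch: the function `latest_generator_checkpoint` is hypothetical, and the file name pattern `<run_tag>_netG_epoch_<N>.pth` is an assumption inferred from the example paths in this section, not a documented convention.

```python
import re
import tempfile
from pathlib import Path

# Assumed naming pattern, inferred from the example paths:
#   <run_tag>_netG_epoch_<N>.pth
_CKPT_RE = re.compile(r"^(?P<tag>.+)_netG_epoch_(?P<epoch>\d+)\.pth$")


def latest_generator_checkpoint(checkpoint_dir, run_tag):
    """Return the generator checkpoint with the highest epoch, or None."""
    best_epoch, best_path = -1, None
    for path in Path(checkpoint_dir).glob("*.pth"):
        match = _CKPT_RE.match(path.name)
        if match and match["tag"] == run_tag:
            # Compare epochs as integers so that 70 sorts after 9.
            epoch = int(match["epoch"])
            if epoch > best_epoch:
                best_epoch, best_path = epoch, path
    return best_path


# Demonstration on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (9, 39, 70):
        (Path(tmp) / f"my_run_netG_epoch_{epoch}.pth").touch()
    (Path(tmp) / "other_run_netG_epoch_99.pth").touch()
    latest_name = latest_generator_checkpoint(tmp, "my_run").name
```

The returned path can then be passed as the `--checkpoint_path` argument shown above.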

Fine-tune a generator checkpoint with supervised training:

python finetune_model.py --checkpoint checkpoints/cnn_dis_lstm_gen_noalt_new_netG_epoch_39.pth --output_path finetuned.pth

Insights and directions for improvement