amazon-science / unconditional-time-series-diffusion

Official PyTorch implementation of TSDiff models presented in the NeurIPS 2023 paper "Predict, Refine, Synthesize: Self-Guiding Diffusion Models for Probabilistic Time Series Forecasting"
Apache License 2.0
143 stars 27 forks

Train on custom dataset #7

Open khabalghoul opened 10 months ago

khabalghoul commented 10 months ago

Hi! How are you?

I found that TSDiff could be a great tool for generating EEG data. I have a dataset containing the channel measurements from an EEG recorded during an experiment, and I would like to train your model on this data. How can I train your model with a custom dataset?

Thanks!

abdulfatir commented 10 months ago

Hi @tomyjara!

You can use something like this to build a custom dataset.

  1. Create a JSON Lines file with your time series data. Each line holds one time series as a JSON object with two keys: start (the start timestamp) and target (the actual time series values). I have attached an example file (see also the sample lines after this list). Note that the time series are not required to have the same start or length.

  2. Use the following function to load the file as a GluonTS dataset.

from pathlib import Path
from typing import Optional

from gluonts.dataset.split import split
from gluonts.dataset.common import (
    MetaData,
    TrainDatasets,
    FileDataset,
)

def get_custom_dataset(
    jsonl_path: Path,
    freq: str,
    prediction_length: int,
    split_offset: Optional[int] = None,
) -> TrainDatasets:
    """Creates a custom GluonTS dataset from a JSONLines file and
    the given parameters.

    Parameters
    ----------
    jsonl_path
        Path to a JSONLines file with time series
    freq
        Frequency in pandas format
        (e.g., `H` for hourly, `D` for daily)
    prediction_length
        Prediction length
    split_offset, optional
        Offset to split data into train and test sets, by default None

    Returns
    -------
        A GluonTS TrainDatasets object
    """
    if split_offset is None:
        # By default, hold out the last `prediction_length` steps for testing.
        split_offset = -prediction_length

    metadata = MetaData(freq=freq, prediction_length=prediction_length)
    test_ts = FileDataset(jsonl_path, freq)
    # Everything before the offset becomes the training split.
    train_ts, _ = split(test_ts, offset=split_offset)
    dataset = TrainDatasets(metadata=metadata, train=train_ts, test=test_ts)
    return dataset
  3. This get_custom_dataset function can be used as a drop-in replacement for the dataset loading at https://github.com/amazon-science/unconditional-time-series-diffusion/blob/50f52da1c583d2eece4da8e933f34b73dc249a75/bin/train_model.py#L135.

  4. Modify the default config appropriately, especially the context length, lags, etc.
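
For illustration, the JSON Lines format from step 1 might look like this (the timestamps and values below are made up):

{"start": "2021-01-01 00:00:00", "target": [1.0, 2.0, 3.0, 4.0, 5.0]}
{"start": "2021-03-15 12:30:00", "target": [0.5, 0.7, 0.2]}

Each line must be a complete, self-contained JSON object.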

Thanks @marcelkollovieh for helping with the response!
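
For completeness, calling the helper might look like this; the path, freq, and prediction_length below are illustrative assumptions, not values from the repo:

from pathlib import Path

# Illustrative values (assumptions): adjust to your data and config.
dataset = get_custom_dataset(
    jsonl_path=Path("data/eeg/train.jsonl"),
    freq="H",
    prediction_length=24,
)

# dataset.train and dataset.test can then be used wherever
# bin/train_model.py uses the dataset returned by its own loader.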

gulugulu888 commented 2 weeks ago

(tsdiff) rrr@rr:~/unconditional-time-series-diffusion$ python bin/train_model.py -c configs/train_fdr.yaml
DEBUG:root:Before importing pykeops...
DEBUG:root:After importing pykeops!
INFO:uncond_ts_diff.arch.s4:Pykeops installation found.
WARNING: Skipping key sampler_params!
WARNING:root:Cannot infer loader for /home/h/unconditional-time-series-diffusion/data/fdr/CAS/dummy_custom_data.json:Zone.Identifier.
WARNING:root:Cannot infer loader for /home/h/unconditional-time-series-diffusion/data/fdr/CAS/train - 副本.json:Zone.Identifier.
WARNING:root:Cannot infer loader for /home/h/unconditional-time-series-diffusion/data/fdr/CAS/dummy_custom_data.json:Zone.Identifier.
WARNING:root:Cannot infer loader for /home/h/unconditional-time-series-diffusion/data/fdr/CAS/train - 副本.json:Zone.Identifier.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
INFO:bin/train_model.py:Logging to ./lightning_logs/version_44
/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:108: PossibleUserWarning: You defined a validation_step but have no val_dataloader. Skipping val loop.
  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name     ┃ Type            ┃ Params ┃
┡━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ scaler   │ MeanScaler      │ 0      │
│ 1 │ embedder │ FeatureEmbedder │ 1      │
│ 2 │ backbone │ BackboneModel   │ 193 K  │
└───┴──────────┴─────────────────┴────────┘
Trainable params: 193 K
Non-trainable params: 0
Total params: 193 K
Total estimated model params size (MB): 0
DEBUG:fsspec.local:open file: /home/h/unconditional-time-series-diffusion/lightning_logs/version_44/hparams.yaml
DEBUG:fsspec.local:open file: /home/h/unconditional-time-series-diffusion/lightning_logs/version_44/hparams.yaml
Epoch 0/99 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/-- 0:00:00 • -:--:-- 0.00it/s
Traceback (most recent call last):
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/dataset/jsonl.py", line 127, in __iter__
    yield json.loads(line)
orjson.JSONDecodeError: unexpected end of data: line 2 column 1 (char 3)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "bin/train_model.py", line 286, in <module>
    main(config=config, log_dir=args.out_dir)
  File "bin/train_model.py", line 224, in main
    trainer.fit(model, train_dataloaders=data_loader)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in _run
    results = self._run_stage()
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1191, in _run_stage
    self._run_train()
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1214, in _run_train
    self.fit_loop.run()
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 194, in run
    self.on_run_start(*args, **kwargs)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 160, in on_run_start
    _ = iter(data_fetcher)  # creates the iterator inside the fetcher
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 180, in __iter__
    self.prefetching()
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 243, in prefetching
    self._fetch_next_batch(iterator)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
    batch = next(iterator)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py", line 571, in __next__
    return self.request_next_batch(self.loader_iters)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/pytorch_lightning/trainer/supporters.py", line 583, in request_next_batch
    return apply_to_collection(loader_iters, Iterator, next)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 64, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/itertools.py", line 181, in __iter__
    yield from itertools.islice(self.iterable, self.length)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 103, in __iter__
    yield from self.transformation(
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 124, in __call__
    for data_entry in data_it:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/dataset/loader.py", line 37, in __call__
    yield from batcher(data, self.batch_size)
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/itertools.py", line 100, in get_batch
    return list(itertools.islice(it, batch_size))
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 178, in __call__
    for data_entry in data_it:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/itertools.py", line 77, in __iter__
    for el in self.iterable:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/itertools.py", line 125, in __iter__
    for element in self.iterable:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 103, in __iter__
    yield from self.transformation(
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 124, in __call__
    for data_entry in data_it:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 124, in __call__
    for data_entry in data_it:
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/transform/_base.py", line 124, in __call__
    for data_entry in data_it:
  [Previous line repeated 8 more times]
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/dataset/__init__.py", line 56, in __iter__
    yield from self.iter_sequential()
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/dataset/__init__.py", line 47, in iter_sequential
    yield from dataset
  File "/home/h/anaconda3/envs/tsdiff/lib/python3.8/site-packages/gluonts/dataset/jsonl.py", line 129, in __iter__
    raise GluonTSDataError(
gluonts.exceptions.GluonTSDataError: Could not read json line 0, b'{\r\n'

I've used the dummy JSON file as train.json, but the error above comes up.
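
The b'{\r\n' in the final error suggests the train file is a pretty-printed JSON document (a single object spread over several lines, with Windows line endings) rather than JSON Lines, where every line must be one complete JSON object. A minimal conversion sketch, assuming the file holds a top-level list of {start, target} records (the paths below are illustrative):

import json
from pathlib import Path

# Illustrative paths (assumptions), not the repo's actual layout.
src = Path("data/fdr/CAS/train.json")   # pretty-printed JSON input
dst = Path("data/fdr/CAS/train.jsonl")  # JSON Lines output

with src.open(encoding="utf-8") as f:
    # Assumes the file holds a list of {"start": ..., "target": ...} objects.
    records = json.load(f)

with dst.open("w", encoding="utf-8", newline="\n") as f:
    for record in records:
        # Write exactly one complete JSON object per line, with LF endings.
        f.write(json.dumps(record) + "\n")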