mmcdermott / EventStreamGPT

Dataset and modelling infrastructure for modelling "event streams": sequences of continuous-time, multivariate events with complex internal dependencies.
https://eventstreamml.readthedocs.io/en/latest/
MIT License

Pretrain crashes during collating static and dynamic data in pytorch_dataset #66

Closed: juancq closed this issue 1 year ago

juancq commented 1 year ago

When running pretraining, the program crashes in the middle of the first epoch. The error seems to be coming from here:

https://github.com/mmcdermott/EventStreamGPT/blob/bb689ae8f95aef2ebb243d0ba06e423eefee9d90/EventStream/data/pytorch_dataset.py#L534-L534

Here is the error trace:

Traceback (most recent call last):
  File "H:\pyscripts\apdc_gpt\pretrain.py", line 41, in main
    return train(cfg)
  File "H:\pyscripts\apdc_gpt\EventStream\utils.py", line 381, in wrap
    raise ex
  File "H:\pyscripts\apdc_gpt\EventStream\utils.py", line 376, in wrap
    fn_return = task_func(*args, **kwargs)
  File "H:\pyscripts\apdc_gpt\EventStream\transformer\lightning_modules\generative_modeling.py", line 678, in train
    trainer.fit(model=LM, train_dataloaders=train_dataloader, val_dataloaders=tuning_dataloader)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\trainer\call.py", line 42, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 975, in _run
    results = self._run_stage()
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\trainer\trainer.py", line 1018, in _run_stage
    self.fit_loop.run()
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\fit_loop.py", line 201, in run
    self.advance()
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\training_epoch_loop.py", line 189, in advance
    batch = next(data_fetcher)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\fetchers.py", line 136, in __next__
    self._fetch_next_batch(self.dataloader_iter)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\loops\fetchers.py", line 150, in _fetch_next_batch
    batch = next(iterator)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\utilities\combined_loader.py", line 284, in __next__
    out = next(self._iterator)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\lightning\pytorch\utilities\combined_loader.py", line 65, in __next__
    out[i] = next(self.iterators[i])
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\utils\data\dataloader.py", line 633, in __next__
    data = self._next_data()
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\utils\data\dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\utils\data\dataloader.py", line 1371, in _process_data
    data.reraise()
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\_utils.py", line 644, in reraise
    raise exception
TypeError: Caught TypeError in DataLoader worker process 2.
Original Traceback (most recent call last):
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "D:\Users\juanqa\Anaconda3\envs\esgpt\lib\site-packages\mixins\timeable.py", line 101, in wrapper_timing
    out = fn(self, *args, **kwargs)
  File "H:\pyscripts\apdc_gpt\EventStream\data\pytorch_dataset.py", line 710, in collate
    return self.__static_and_dynamic_collate(batch)
  File "H:\pyscripts\apdc_gpt\EventStream\data\pytorch_dataset.py", line 543, in __static_and_dynamic_collate
    max_n_static = max(len(e["static_indices"]) for e in batch)
  File "H:\pyscripts\apdc_gpt\EventStream\data\pytorch_dataset.py", line 543, in <genexpr>
    max_n_static = max(len(e["static_indices"]) for e in batch)
TypeError: object of type 'NoneType' has no len()

I have not been able to replicate it consistently; pretraining ran to completion once. I am unsure what triggers it: the number of dataloader worker processes, the batch size, or some combination of other factors.

juancq commented 1 year ago

With num_data_loader_workers < 3, the pretraining script completes the first epoch without errors. With num_data_loader_workers set to 3 or 4, I can replicate the error above during the first epoch.

I am using a batch size of 32 and a validation batch size of 32.
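To check whether the crash is tied to particular subjects rather than to the worker count, one option is to scan the dataset items in the main process and list those whose `static_indices` is None. The sketch below assumes items are dict-like, as the `e["static_indices"]` lookup in the traceback suggests; the function name `find_items_missing_static` and the toy data are hypothetical, not part of EventStreamGPT.

```python
from typing import Any, List, Mapping, Sequence


def find_items_missing_static(
    dataset: Sequence[Mapping[str, Any]], key: str = "static_indices"
) -> List[int]:
    """Return the positions of items whose `key` field is None or absent.

    Iterates in the calling process, so the result does not depend on the
    DataLoader worker count, batch size, or shuffling order.
    """
    missing: List[int] = []
    for idx in range(len(dataset)):
        if dataset[idx].get(key) is None:
            missing.append(idx)
    return missing


# Hypothetical toy items standing in for what the dataset's __getitem__ returns.
toy_dataset = [
    {"static_indices": [3, 7]},
    {"static_indices": None},  # a subject with no static data observed
    {"static_indices": [5]},
]
print(find_items_missing_static(toy_dataset))  # [1]
```

A non-empty result would point at the data rather than at the DataLoader settings, since any batch that happens to include one of those items would fail in the same collate step.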

mmcdermott commented 1 year ago

@juancq, is this on the current main branch of the code, on the dev branch, or on a modified version? And is it possible that some of your subjects don't have any static data observed? Based on the error, that is what is causing the issue, but I've not encountered that situation on any of my datasets, so I want to make sure it is expected. Either way, it should be a relatively simple fix. Since I don't have a test case for this issue, I can't be sure my change works, but could you try out the branch with a possible fix and see? https://github.com/mmcdermott/EventStreamGPT/compare/dev...fix_static_data_bug?expand=1

Note that this branch, being derived from dev, includes some recent changes to the dataset structure, so it may conflict with any existing datasets you have created and saved to disk. If that is an issue, I can help migrate your datasets, though it may be best to re-run your dataset creation script to get a new slice, unless you are too resource constrained.
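As a hedged illustration of the idea only (this is not the code from the linked branch), a None-safe version of the static-padding step could treat a missing `static_indices` as an empty list before measuring lengths. The sketch uses plain Python lists instead of tensors, and the function name `pad_static_indices` and the `*_mask` output key are hypothetical.

```python
from typing import Any, Dict, List, Sequence


def pad_static_indices(
    batch: Sequence[Dict[str, Any]],
    key: str = "static_indices",
    pad_value: int = 0,
) -> Dict[str, List[List[int]]]:
    """Right-pad each item's static indices to the longest in the batch.

    Items whose `key` is None (no static data observed) are treated as empty
    lists, so the length computation never calls len(None).
    """
    per_item: List[List[int]] = []
    for e in batch:
        raw = e.get(key)
        # An explicit None check avoids truth-testing arrays or tensors.
        per_item.append([] if raw is None else list(raw))

    # default=0 also keeps max() from raising on an empty batch.
    max_n_static = max((len(s) for s in per_item), default=0)

    padded: List[List[int]] = []
    mask: List[List[int]] = []
    for s in per_item:
        n_pad = max_n_static - len(s)
        padded.append(s + [pad_value] * n_pad)
        mask.append([1] * len(s) + [0] * n_pad)
    return {key: padded, f"{key}_mask": mask}


batch = [{"static_indices": [3, 7]}, {"static_indices": None}]
out = pad_static_indices(batch)
print(out["static_indices"])       # [[3, 7], [0, 0]]
print(out["static_indices_mask"])  # [[1, 1], [0, 0]]
```

Normalizing None at the top leaves the rest of the padding logic unchanged. An alternative design would be to have the dataset's `__getitem__` always return an empty list for subjects without static data, which addresses the problem at its source rather than in each collate path.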

Alternatively, you can apply just the diff shown in the link above to your local code and see if that solves it. If it does, it would be great to add a test case for this as well.
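A regression test along the lines requested above might look like the following sketch. The real collate method is private (`__static_and_dynamic_collate`), so the test uses a hypothetical stand-in, `max_n_static`, that mirrors the failing line from the traceback; in the repository one would instead build a batch containing a subject without static data and pass it to the dataset's `collate`. The class and test names are illustrative, and the tests run under pytest or `python -m unittest`.

```python
import unittest
from typing import Any, Dict, List, Sequence


def max_n_static(batch: Sequence[Dict[str, Any]], key: str = "static_indices") -> int:
    """None-safe stand-in for the failing `max(len(e[key]) for e in batch)` line."""
    return max((0 if e.get(key) is None else len(e[key]) for e in batch), default=0)


class StaticDataCollateTest(unittest.TestCase):
    def setUp(self) -> None:
        # A batch where one subject has no static data observed.
        self.batch: List[Dict[str, Any]] = [
            {"static_indices": [4, 9, 2]},
            {"static_indices": None},
        ]

    def test_original_expression_fails_on_missing_static(self) -> None:
        # Documents the bug from the traceback in this issue.
        with self.assertRaises(TypeError):
            max(len(e["static_indices"]) for e in self.batch)

    def test_none_safe_length_ignores_missing_static(self) -> None:
        self.assertEqual(max_n_static(self.batch), 3)

    def test_all_missing_static(self) -> None:
        self.assertEqual(max_n_static([{"static_indices": None}]), 0)
```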

juancq commented 1 year ago

@mmcdermott This was on the dev branch. Your commit 29c29f13d732468ccb217b559439b2abc41d25b9 fixed the issue.

mmcdermott commented 1 year ago

Great! Then this should be fixed in dev by #67. Let me know if you're still seeing any issues.