mosaicml / diffusion


Error when training with local precomputed features #33

Open tungdq212 opened 1 year ago

tungdq212 commented 1 year ago

Bug when training locally with LocalDataset

Here is my config (with some personal paths removed), used to run mosaicml's diffusion:

algorithms:
  low_precision_groupnorm:
    attribute: unet
    precision: amp_fp16
  low_precision_layernorm:
    attribute: unet
    precision: amp_fp16
model:
  _target_: diffusion.models.models.stable_diffusion_2
  model_name: runwayml/stable-diffusion-v1-5
  pretrained: true
  precomputed_latents: true
  encode_latents_in_fp16: true
  # fsdp: false
  fsdp: true
  val_metrics:
    - _target_: torchmetrics.MeanSquaredError
  val_guidance_scales: []
  loss_bins: []
dataset:
  train_batch_size: 2048 # TODO: explore composer config
  eval_batch_size: 16 # Should be 8 per device
  train_dataset:
    _target_: diffusion.datasets.pixta.pixta.build_custom_dataloader
    data_path: ...
    feature_dim: 32
    num_workers: 8
    pin_memory: false
  eval_dataset:
    _target_: diffusion.datasets.pixta.pixta.build_custom_dataloader
    data_path: ...
    feature_dim: 32
    num_workers: 8
    pin_memory: false
optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-5
  weight_decay: 0.01
scheduler:
  _target_: composer.optim.ConstantWithWarmupScheduler
  t_warmup: 10ba
logger:
  wandb:
    _target_: composer.loggers.wandb_logger.WandBLogger
    name: ${name}
    project: ${project}
    group: ${name}
callbacks:
  speed_monitor:
    _target_: composer.callbacks.speed_monitor.SpeedMonitor
    window_size: 10
  lr_monitor:
    _target_: composer.callbacks.lr_monitor.LRMonitor
  memory_monitor:
    _target_: composer.callbacks.memory_monitor.MemoryMonitor
  runtime_estimator:
    _target_: composer.callbacks.runtime_estimator.RuntimeEstimator
  optimizer_monitor:
    _target_: composer.callbacks.OptimizerMonitor
trainer:
  _target_: composer.Trainer
  device: gpu
  max_duration: 10ep
  eval_interval: 2ep
  device_train_microbatch_size: 40
  run_name: ${name}
  seed: ${seed}
  scale_schedule_ratio: ${scale_schedule_ratio}
  save_folder:  outputs/${project}/${name}
  save_interval: 5ep
  save_overwrite: true
  autoresume: false
  fsdp_config:
    sharding_strategy: "SHARD_GRAD_OP"
    state_dict_type: "full"
    mixed_precision: 'PURE'
    activation_checkpointing: true

Here is my dataset and dataloader code:

class CustomDataset(Array, Dataset):
    def __init__(self, 
                 data_path, 
                 feature_dim=64):
        self.feature_dim = feature_dim

        index_file = os.path.join(data_path, 'index.json')
        # Use a context manager so the index file handle is closed promptly.
        with open(index_file) as f:
            data = json.load(f)
        if data['version'] != 2:
            raise ValueError(f'Unsupported streaming data version: {data["version"]}. '
                             'Expected version 2.')
        shards = []
        for info in data['shards']:
            shard = reader_from_json(data_path, None, info)
            shards.append(shard)

        self.shards = shards
        samples_per_shard = np.array([shard.samples for shard in shards], np.int64)
        self.length = samples_per_shard.sum()
        self.spanner = Spanner(samples_per_shard)

    def __len__(self):
        return self.length

    @property
    def size(self) -> int:
        """Get the size of the dataset in samples.

        Returns:
            int: Number of samples.
        """
        return self.length

    # def __getitem__(self, index):
    def get_item(self, index):
        shard_id, shard_sample_id = self.spanner[index]
        shard = self.shards[shard_id]
        sample = shard[shard_sample_id]
        out = {}
        if 'caption_latents' in sample:
            out['caption_latents'] = torch.from_numpy(
                np.frombuffer(sample['caption_latents'], dtype=np.float16).copy()).reshape(77, 768)

        if 'image_latents' in sample:
            out['image_latents'] = torch.from_numpy(np.frombuffer(sample['image_latents'],
                                                                  dtype=np.float16).copy()).reshape(4, self.feature_dim, self.feature_dim)
        return out

def build_custom_dataloader(
    batch_size: int,
    data_path: str,
    image_root: str = None,
    tokenizer_name_or_path: str = 'runwayml/stable-diffusion-v1-5',
    caption_drop_prob: float = 0.0,
    resize_size: int = 512,
    feature_dim: int = 64,
    drop_last: bool = True,
    shuffle: bool = True, #TODO pass shuffle to dataloader
    **dataloader_kwargs,
):
    print('Using precomputed features!!!')
    dataset = CustomDataset(data_path, feature_dim=feature_dim)
    drop_last = False

    if isinstance(dataset, IterableDataset):
        print('Using IterableDataset!!!')
        sampler = None
    else:
        print('Using Sampler!!!')
        sampler = dist.get_sampler(dataset, drop_last=drop_last, shuffle=shuffle)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        shuffle=shuffle if sampler is None else False,
        **dataloader_kwargs,
    )

    return dataloader
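
The two reshape calls in get_item only succeed if each stored buffer has exactly the expected number of bytes. As a quick sanity check, here is a sketch that assumes the latents were written as raw float16 buffers, which is what the np.frombuffer calls above imply. `expected_nbytes` is a hypothetical helper, not part of the repository:

```python
from math import prod

FLOAT16_BYTES = 2  # np.float16 occupies two bytes per element


def expected_nbytes(shape, itemsize=FLOAT16_BYTES):
    """Return the byte length a raw buffer needs in order to reshape to `shape`."""
    return prod(shape) * itemsize


caption_nbytes = expected_nbytes((77, 768))     # caption_latents
image_nbytes_32 = expected_nbytes((4, 32, 32))  # feature_dim=32, as in the config
image_nbytes_64 = expected_nbytes((4, 64, 64))  # feature_dim=64, the class default

print(caption_nbytes, image_nbytes_32, image_nbytes_64)  # 118272 8192 32768
```

Comparing len(sample['image_latents']) against these numbers for a few samples shows whether the feature_dim in the config matches how the features were written.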

I got this error after epoch 0 finished and epoch 1 started:

Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 262, in poll
    return self._poll(timeout)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 429, in _poll
    r = wait([self], timeout)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3007309) is killed by signal: Aborted. 

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/tungduongquang/workspace/mosaicml/image-generation/run.py", line 26, in <module>
    main()
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/tungduongquang/workspace/mosaicml/image-generation/run.py", line 22, in main
    return train(config)
  File "/home/tungduongquang/workspace/mosaicml/image-generation/diffusion/train.py", line 134, in train
    return eval_and_then_train()
  File "/home/tungduongquang/workspace/mosaicml/image-generation/diffusion/train.py", line 132, in eval_and_then_train
    trainer.fit()
  File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 1796, in fit
    self._train_loop()
  File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 1938, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/home/tungduongquang/workspace/mosaicml/composer/composer/trainer/trainer.py", line 2924, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1316, in _next_data
    idx, data = self._get_data()
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1282, in _get_data
    success, data = self._try_get_data()
  File "/home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 3007309) exited unexpectedly
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/datal │
│ oader.py:1120 in _try_get_data                                                                   │
│                                                                                                  │
│   1117 │   │   # Returns a 2-tuple:                                                              │
│   1118 │   │   #   (bool: whether successfully get data, any: data if successful else None)      │
│   1119 │   │   try:                                                                              │
│ ❱ 1120 │   │   │   data = self._data_queue.get(timeout=timeout)                                  │
│   1121 │   │   │   return (True, data)                                                           │
│   1122 │   │   except Exception as e:                                                            │
│   1123 │   │   │   # At timeout and error, we manually check whether any worker has              │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/queues.py:113 in get │
│                                                                                                  │
│   110 │   │   │   try:                                                                           │
│   111 │   │   │   │   if block:                                                                  │
│   112 │   │   │   │   │   timeout = deadline - time.monotonic()                                  │
│ ❱ 113 │   │   │   │   │   if not self._poll(timeout):                                            │
│   114 │   │   │   │   │   │   raise Empty                                                        │
│   115 │   │   │   │   elif not self._poll():                                                     │
│   116 │   │   │   │   │   raise Empty                                                            │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py:262 in │
│ poll                                                                                             │
│                                                                                                  │
│   259 │   │   """Whether there is any input available to be read"""                              │
│   260 │   │   self._check_closed()                                                               │
│   261 │   │   self._check_readable()                                                             │
│ ❱ 262 │   │   return self._poll(timeout)                                                         │
│   263 │                                                                                          │
│   264 │   def __enter__(self):                                                                   │
│   265 │   │   return self                                                                        │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py:429 in │
│ _poll                                                                                            │
│                                                                                                  │
│   426 │   │   return self._recv(size)                                                            │
│   427 │                                                                                          │
│   428 │   def _poll(self, timeout):                                                              │
│ ❱ 429 │   │   r = wait([self], timeout)                                                          │
│   430 │   │   return bool(r)                                                                     │
│   431                                                                                            │
│   432                                                                                            │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/multiprocessing/connection.py:936 in │
│ wait                                                                                             │
│                                                                                                  │
│   933 │   │   │   │   deadline = time.monotonic() + timeout                                      │
│   934 │   │   │                                                                                  │
│   935 │   │   │   while True:                                                                    │
│ ❱ 936 │   │   │   │   ready = selector.select(timeout)                                           │
│   937 │   │   │   │   if ready:                                                                  │
│   938 │   │   │   │   │   return [key.fileobj for (key, events) in ready]                        │
│   939 │   │   │   │   else:                                                                      │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/selectors.py:416 in select           │
│                                                                                                  │
│   413 │   │   │   timeout = math.ceil(timeout * 1e3)                                             │
│   414 │   │   ready = []                                                                         │
│   415 │   │   try:                                                                               │
│ ❱ 416 │   │   │   fd_event_list = self._selector.poll(timeout)                                   │
│   417 │   │   except InterruptedError:                                                           │
│   418 │   │   │   return ready                                                                   │
│   419 │   │   for fd, event in fd_event_list:                                                    │
│                                                                                                  │
│ /home/tungduongquang/miniconda3/envs/mosaicml/lib/python3.9/site-packages/torch/utils/data/_util │
│ s/signal_handling.py:66 in handler                                                               │
│                                                                                                  │
│   63 │   def handler(signum, frame):                                                             │
│   64 │   │   # This following call uses `waitid` with WNOHANG from C side. Therefore,            │
│   65 │   │   # Python can still get and update the process status successfully.                  │
│ ❱ 66 │   │   _error_if_any_worker_fails()                                                        │
│   67 │   │   if previous_handler is not None:                                                    │
│   68 │   │   │   assert callable(previous_handler)                                               │
│   69 │   │   │   previous_handler(signum, frame)                                                 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: DataLoader worker (pid 3007309) is killed by signal: Aborted.

Please help!
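
A worker killed by a signal hides whatever went wrong inside the worker process. A common first step (general PyTorch debugging practice, not something suggested in this thread) is to read the samples in the main process, which is what the DataLoader does with num_workers=0, so that Python exceptions surface directly. The sketch below assumes only that the dataset supports len() and integer indexing. `scan_dataset` is a hypothetical helper:

```python
def scan_dataset(dataset, passes=2, log_every=0):
    """Index every sample in the current process and collect Python errors.

    Two passes are the default because the reported crash appeared when the
    second epoch started. Returns a list of (pass_number, index, exception).
    """
    failures = []
    for pass_number in range(passes):
        for index in range(len(dataset)):
            if log_every and index % log_every == 0:
                # Flush so the last line survives if the process itself aborts.
                print(f'pass {pass_number}: reading sample {index}', flush=True)
            try:
                dataset[index]
            except Exception as exc:
                failures.append((pass_number, index, exc))
    return failures


if __name__ == '__main__':
    # Stand-in dataset for demonstration only: a plain list never fails.
    print(scan_dataset([{'x': 1}, {'x': 2}]))
```

A native abort, as opposed to a Python exception, would still terminate the process, but with log_every set the last printed index narrows down where it happened.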

Landanjs commented 1 year ago

Apologies for the delay. Is the training dataset you are using in MDS streaming format? If so, you can use the StreamingDataset class to load the data even if it is stored locally.

hieuphung97 commented 1 year ago

@Landanjs I have the same issue when loading data in MDS streaming format. I built my custom dataset purely on PyTorch's Dataset, not StreamingDataset. Is that the problem?

Landanjs commented 1 year ago

@hieuphung97 Is there a reason you aren't using a subclass of StreamingDataset to load your data? A custom dataset built purely on PyTorch's Dataset may miss some of the logic needed to load MDS shards, so we recommend subclassing StreamingDataset if possible.