hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[BUG]: Stable diffusion training throws "RuntimeError: Only float32 type is supported for now" when xformers is installed #3819

Closed BugFreeee closed 1 year ago

BugFreeee commented 1 year ago

I was following the environment setup in the Stable Diffusion README. Environment: I tried both conda and Docker, and the failure occurs in both. Command to reproduce: CUDA_VISIBLE_DEVICES=1 bash train_colossalai.sh. At this point everything works fine, but after installing xformers with "pip install xformers==0.0.12", the script fails with the error shown in the log below.
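The report pins xformers to 0.0.12 but does not list the versions of the other packages that appear in the traceback. Here is a minimal sketch that prints those installed versions so they can be pasted into a report. It uses only the standard library's importlib.metadata. The package list is an assumption taken from the module paths in the traceback (torch, xformers, colossalai, lightning), and the helper name installed_versions is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed list: distributions whose modules show up in the traceback.
DEFAULT_PACKAGES = ("torch", "xformers", "colossalai", "lightning")


def installed_versions(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```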

BugFreeee commented 1 year ago
(ldm) root@46397fb48a9b:/workspace/examples/images/diffusion# CUDA_VISIBLE_DEVICES=1 bash train_colossalai.sh
WARNING:root:Triton is not available, some optimizations will not be enabled.
Error No module named 'triton'
Using base config ['configs/Teyvat/train_colossalai_teyvat.yaml']
Global seed set to 23
Using ckpt_path = diffuser_root_dir/512-base-ema.ckpt
LatentDiffusion: Running in v-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
building MemoryEfficientAttnBlock with 512 in_channels...
/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:578: LightningDeprecationWarning: The Trainer argument `auto_select_gpus` has been deprecated in v1.9.0 and will be removed in v2.0.0. Please use the function `lightning.pytorch.accelerators.find_usable_cuda_devices` instead.
  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:datasets.builder:Found cached dataset teyvat (/root/.cache/huggingface/datasets/Fazzie___teyvat/train/0.0.0/62e3cc07a1a94bcb7c0d02f703087023dd935272664b2da5525b893724f24701)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 639.57it/s]
WARNING:datasets.builder:Found cached dataset teyvat (/root/.cache/huggingface/datasets/Fazzie___teyvat/train/0.0.0/62e3cc07a1a94bcb7c0d02f703087023dd935272664b2da5525b893724f24701)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1198.03it/s]
train, Dataset, 234
accumulate_grad_batches = 1
Setting learning rate to 1.60e-03 = 1 (accumulate_grad_batches) * 1 (num_gpus) * 16 (batchsize) * 1.00e-04 (base_lr)
/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/configuration_validator.py:108: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
  rank_zero_warn(
WARNING:datasets.builder:Found cached dataset teyvat (/root/.cache/huggingface/datasets/Fazzie___teyvat/train/0.0.0/62e3cc07a1a94bcb7c0d02f703087023dd935272664b2da5525b893724f24701)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1301.77it/s]
Missing logger folder: /tmp/2023-05-23T16-19-09_train_colossalai_teyvat/diff_tb
WARNING:datasets.builder:Found cached dataset teyvat (/root/.cache/huggingface/datasets/Fazzie___teyvat/train/0.0.0/62e3cc07a1a94bcb7c0d02f703087023dd935272664b2da5525b893724f24701)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1011.16it/s]
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 1024 and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
DiffusionWrapper has 865.91 M params.
=========================================================================================
No pre-built kernel is found, build and load the cpu_adam kernel during runtime now
=========================================================================================
Emitting ninja build file /root/.cache/colossalai/torch_extensions/torch1.12_cu11.3/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 1.5531184673309326 seconds
=========================================================================================
No pre-built kernel is found, build and load the fused_optim kernel during runtime now
=========================================================================================
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/colossalai/torch_extensions/torch1.12_cu11.3/build.ninja...
Building extension module fused_optim...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_optim...
Time to load fused_optim op: 1.2574388980865479 seconds
searching chunk configuration is completed in 0.22 s.
used number: 825.80 MB, wasted number: 6.52 MB
total wasted percentage is 0.78%
Project config
model:
  base_learning_rate: 0.0001
  params:
    parameterization: v
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    ckpt: diffuser_root_dir/512-base-ema.ckpt
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: txt
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: false
    scheduler_config:
      warm_up_steps:
      - 1
      cycle_lengths:
      - 10000000000000
      f_start:
      - 1.0e-06
      f_max:
      - 0.0001
      f_min:
      - 1.0e-10
    unet_config:
      use_checkpoint: true
      use_fp16: true
      image_size: 32
      in_channels: 4
      out_channels: 4
      model_channels: 320
      attention_resolutions:
      - 4
      - 2
      - 1
      num_res_blocks: 2
      channel_mult:
      - 1
      - 2
      - 4
      - 4
      num_head_channels: 64
      use_spatial_transformer: true
      use_linear_in_transformer: true
      transformer_depth: 1
      context_dim: 1024
      legacy: false
    first_stage_config:
      embed_dim: 4
      monitor: val/rec_loss
      ddconfig:
        double_z: true
        z_channels: 4
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult:
        - 1
        - 2
        - 4
        - 4
        num_res_blocks: 2
        attn_resolutions: []
        dropout: 0.0
      lossconfig: null
    cond_stage_config:
      freeze: true
      layer: penultimate
    use_fp16: true
data:
  batch_size: 16
  num_workers: 4
  train:
    target: ldm.data.teyvat.hf_dataset
    params:
      path: Fazzie/Teyvat
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
      - target: torchvision.transforms.RandomHorizontalFlip

Lightning config
trainer:
  accelerator: gpu
  devices: 1
  log_gpu_memory: all
  max_epochs: 10
  precision: 16
  auto_select_gpus: false
  strategy:
    use_chunk: true
    enable_distributed_storage: true
    placement_policy: cuda
    force_outputs_fp32: true
    min_chunk_size: 64
  log_every_n_steps: 2
  logger: true
  default_root_dir: /tmp/diff_log/
  accumulate_grad_batches: 1
logger_config:
  wandb:
    name: nowname
    save_dir: /tmp/diff_log/
    offline: opt.debug
    id: nowname

/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loggers/tensorboard.py:188: UserWarning: Could not log computational graph to TensorBoard: The `model.example_input_array` attribute is not set or `input_array` was not given.
  rank_zero_warn(
Epoch 0:   0%|                                                                                                                                                                                                            | 0/15 [00:00<?, ?it/s]Summoning checkpoint.

Traceback (most recent call last):
  File "/workspace/examples/images/diffusion/main.py", line 845, in <module>
    trainer.fit(model, data)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 608, in fit
    call._call_and_handle_interrupt(
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 88, in launch
    return function(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1103, in _run
    results = self._run_stage()
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1182, in _run_stage
    self._run_train()
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1205, in _run_train
    self.fit_loop.run()
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/epoch/training_epoch_loop.py", line 213, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 370, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1347, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/core/module.py", line 1744, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/core/optimizer.py", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/strategies/colossalai.py", line 411, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/colossalai.py", line 73, in optimizer_step
    closure_result = closure()
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 135, in closure
    step_output = self._step_fn()
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/optimizer_loop.py", line 419, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1485, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/strategies/ddp.py", line 351, in training_step
    return self.model(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/colossalai/nn/parallel/data_parallel.py", line 282, in forward
    outputs = self.module(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/lightning/pytorch/overrides/base.py", line 98, in forward
    output = self._forward_module.training_step(*inputs, **kwargs)
  File "/workspace/examples/images/diffusion/ldm/models/diffusion/ddpm.py", line 477, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "/workspace/examples/images/diffusion/ldm/models/diffusion/ddpm.py", line 925, in shared_step
    x, c = self.get_input(batch, self.first_stage_key)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/examples/images/diffusion/ldm/models/diffusion/ddpm.py", line 862, in get_input
    encoder_posterior = self.encode_first_stage(x)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/examples/images/diffusion/ldm/models/diffusion/ddpm.py", line 922, in encode_first_stage
    return self.first_stage_model.encode(x)
  File "/workspace/examples/images/diffusion/ldm/models/autoencoder.py", line 84, in encode
    h = self.encoder(x)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/examples/images/diffusion/ldm/modules/diffusionmodules/model.py", line 499, in forward
    h = self.mid.attn_1(h)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/workspace/examples/images/diffusion/ldm/modules/diffusionmodules/model.py", line 210, in forward
    out = xformers.ops.memory_efficient_attention(q, k, v)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/xformers/ops.py", line 58, in memory_efficient_attention
    return torch.ops.xformers.efficient_attention(query, key, value, False)[0]
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/_ops.py", line 143, in __call__
    return self._op(*args, **kwargs or {})
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/colossalai/tensor/colo_tensor.py", line 184, in __torch_function__
    ret = func(*args, **kwargs)
  File "/root/anaconda3/envs/ldm/lib/python3.9/site-packages/torch/_ops.py", line 143, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Only float32 type is supported for now
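The traceback ends in xformers.ops.memory_efficient_attention(q, k, v). That call comes from the VAE encoder's attention block (ldm/modules/diffusionmodules/model.py, line 210), and the Lightning config sets precision: 16. The error message suggests that this xformers build accepts only float32 inputs.

Under that assumption, one possible workaround is to upcast q, k and v to float32 just for the attention call and cast the result back. The sketch below does this. The helpers attention_in_float32 and reference_attention and the HAS_XFORMERS flag are hypothetical and are not part of the ColossalAI example. reference_attention is a plain PyTorch stand-in, used when xformers is not importable and for testing on CPU.

```python
import torch

try:
    import xformers.ops  # optional dependency; may be absent
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False


def reference_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Dense softmax(q @ k^T / sqrt(d)) @ v over the last two dimensions."""
    scale = q.shape[-1] ** -0.5
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v


def attention_in_float32(q, k, v, attn_fn=None):
    """Run attn_fn on float32 copies of q, k, v and cast the output back to q's dtype.

    Hypothetical workaround for a kernel that rejects non-float32 inputs.
    """
    if attn_fn is None:
        attn_fn = xformers.ops.memory_efficient_attention if HAS_XFORMERS else reference_attention
    out_dtype = q.dtype
    out = attn_fn(q.float(), k.float(), v.float())
    return out.to(out_dtype)


if __name__ == "__main__":
    # CPU smoke test with the reference kernel; shapes are (batch, tokens, channels).
    q = torch.randn(2, 16, 8).half()
    out = attention_in_float32(q, q, q, attn_fn=reference_attention)
    print(out.dtype, tuple(out.shape))
```

This gives up the fp16 savings inside that one call. Later xformers releases accept fp16 inputs to memory_efficient_attention, so upgrading may be an alternative. That depends on a build being available for the torch 1.12 / CUDA 11.3 combination shown in the extension cache path above.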
BugFreeee commented 1 year ago

Any help is appreciated! Also, GPU memory usage is ~11 GB for me on a single A100 at batch size 16, while the table in the README shows only 5.6 GB. Is this because xformers is not enabled?
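Memory-efficient attention can matter at this batch size because a dense (non-xformers) implementation materializes the full attention-score matrix. The back-of-the-envelope sketch below estimates its size. It rests on these assumptions:

- 512×512 crops (the RandomCrop size in the data config)
- a VAE downsampling factor of 8 (three downsampling stages for ch_mult [1, 2, 4, 4])
- num_head_channels: 64
- batch size 16
- 2-byte fp16 elements

The helper names are hypothetical.

```python
def heads_for(channels: int, head_channels: int = 64) -> int:
    """Number of attention heads when each head has head_channels channels."""
    return channels // head_channels


def score_matrix_bytes(batch: int, heads: int, tokens: int, bytes_per_elem: int = 2) -> int:
    """Bytes in one dense (batch, heads, tokens, tokens) attention-score tensor."""
    return batch * heads * tokens * tokens * bytes_per_elem


if __name__ == "__main__":
    image_size = 512                         # RandomCrop size in the data config
    vae_factor = 2 ** (4 - 1)                # ch_mult [1, 2, 4, 4]: assumed 3 downsampling stages
    latent_side = image_size // vae_factor   # 64, matching "image_size: 64" in the model params
    batch = 16
    # attention_resolutions [4, 2, 1]; channels = model_channels 320 * channel_mult
    for ds, channels in ((1, 320), (2, 640), (4, 1280)):
        tokens = (latent_side // ds) ** 2
        heads = heads_for(channels)
        gib = score_matrix_bytes(batch, heads, tokens) / 2**30
        print(f"ds={ds}: {tokens} tokens x {heads} heads -> {gib:.3f} GiB per score tensor")
```

The head counts (5, 10, 20) match the "using N heads" lines in the log. At the highest-resolution level the estimate is 16 × 5 × 4096 × 4096 × 2 bytes = 2.5 GiB for a single score tensor. It ignores activation checkpointing (use_checkpoint: true), the softmax output and every other activation. It therefore shows the scale of the term that memory-efficient attention avoids, but it does not explain the measured ~11 GB versus the README's 5.6 GB.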