Jack000 / glid-3-xl-stable

stable diffusion training
MIT License

cannot disable fp16 #9

Open XavierXiao opened 1 year ago

XavierXiao commented 1 year ago

It seems there is a bug when setting use_fp16 = False; the log is:

Traceback (most recent call last):
  File "scripts/image_train_stable.py", line 150, in <module>
    main()
  File "scripts/image_train_stable.py", line 78, in main
    TrainLoop(
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 194, in run_loop
    self.run_step(batch, cond)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 208, in run_step
    self.forward_backward(batch, cond)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/train_util.py", line 236, in forward_backward
    losses = compute_losses()
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/respace.py", line 96, in training_losses
    return super().training_losses(self._wrap_model(model), *args, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/gaussian_diffusion.py", line 1137, in training_losses
    model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/respace.py", line 133, in __call__
    return self.model(x, new_ts, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 882, in forward
    h = module(h, emb, context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 217, in forward
    x = layer(x, context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 188, in forward
    x = block(x, context=context)
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 140, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/nn.py", line 162, in checkpoint
    return CheckpointFunction.apply(func, len(inputs), *args)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/nn.py", line 174, in forward
    output_tensors = ctx.run_function(*ctx.input_tensors)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 144, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yanq/Codes/glid-3-xl-stable/guided_diffusion/unet.py", line 112, in forward
    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  File "/home/yanq/.conda/envs/ldm/lib/python3.8/site-packages/torch/functional.py", line 327, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found Half
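For reference, the error means the two einsum operands have different dtypes: with use_fp16 = False the UNet weights (and therefore q) are float32, while the context embedding passed into the cross-attention block (attn2) presumably stays half precision, so k ends up float16. Below is a minimal sketch of the mismatch and the kind of defensive cast that resolves it; the tensor shapes and the context.to(x.dtype) cast are illustrative assumptions, not the actual patch applied in the repo.

```python
import torch
from torch import einsum

# Reproduce the reported mismatch: q comes from fp32 weights, while the
# cross-attention context (and hence k) is still half precision.
# Shapes are arbitrary placeholders (batch=2, seq=77, head_dim=64).
q = torch.randn(2, 77, 64, dtype=torch.float32)
k = torch.randn(2, 77, 64, dtype=torch.float16)

try:
    einsum('b i d, b j d -> b i j', q, k)
except RuntimeError as e:
    print(e)  # e.g. "expected scalar type Float but found Half"

# One common fix is to cast the context to the attention layer's dtype
# before projecting it to k/v, e.g. context = context.to(x.dtype) inside
# the cross-attention forward. Here we just cast k directly:
k = k.to(q.dtype)
sim = einsum('b i d, b j d -> b i j', q, k) * (64 ** -0.5)
print(sim.shape)  # torch.Size([2, 77, 77])
```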

Jack000 commented 1 year ago

should be fixed now

XavierXiao commented 1 year ago

Thanks! One interesting thing I found is that fp16 seems to use similar or even more GPU memory. Both fp16 and fp32 can train with bs=8 on a GPU with 48 GB of RAM, and at bs=2 I even found that fp32 takes less memory.
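One plausible explanation (an assumption, not confirmed in this thread) is that guided-diffusion style mixed precision keeps fp32 master weights alongside the fp16 model weights, so the parameter footprint does not shrink; the savings come mostly from fp16 activations, which are small at bs=2. A rough back-of-the-envelope sketch with an arbitrary layer size:

```python
import torch.nn as nn

# Hypothetical accounting of parameter memory under the two modes.
# Assumes mixed precision keeps an fp32 master copy next to the fp16 weights.
model = nn.Linear(4096, 4096)
n_params = sum(p.numel() for p in model.parameters())

fp32_only = n_params * 4                         # 4 bytes per fp32 weight
fp16_plus_master = n_params * 2 + n_params * 4   # fp16 weights + fp32 master copy

print(f"params: {n_params:,}")
print(f"fp32 weights only:       {fp32_only / 2**20:.1f} MiB")
print(f"fp16 + fp32 master copy: {fp16_plus_master / 2**20:.1f} MiB")
```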