ali-vilab / Cones-V2

[NeurIPS 2023] Official implementations for paper: Customizable Image Synthesis with Multiple Subjects
MIT License

RuntimeError: mat1 and mat2 must have the same dtype #4

Closed Kyfafyd closed 1 year ago

Kyfafyd commented 1 year ago

Thanks for your nice work! While running training, I met the following error. Could you please help me figure it out? I strictly followed the environment setup in the README.

Traceback (most recent call last):
  File "/home/zwang/server/Cones-V2/train_cones2.py", line 836, in <module>
    main(args)
  File "/home/zwang/server/Cones-V2/train_cones2.py", line 745, in main
    model_pred = unet(noisy_latents.float(), timesteps, encoder_hidden_states.float()).sample
  File "/home/zwang/anaconda3/envs/cones/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zwang/anaconda3/envs/cones/lib/python3.9/site-packages/diffusers/models/unet_2d_condition.py", line 574, in forward
    sample = self.conv_in(sample)
  File "/home/zwang/anaconda3/envs/cones/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zwang/anaconda3/envs/cones/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/zwang/anaconda3/envs/cones/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
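The traceback shows the generic pattern behind this error: the UNet's weights were loaded in half precision while the inputs were cast to float32 via `.float()`. A minimal sketch (not the repo's code; float64 vs. float32 is used here only so the demo also runs on CPU, where fp16 matmuls may be unsupported):

```python
import torch
import torch.nn as nn

# Illustration only: any dtype mismatch between a module's parameters and
# its input raises this class of RuntimeError.
layer = nn.Linear(4, 4).double()  # parameters in float64
x = torch.randn(2, 4)             # input in float32

try:
    layer(x)
except RuntimeError as e:
    print(e)  # mat1 and mat2 must have the same dtype

# The usual fix: cast the input to the module's parameter dtype.
out = layer(x.to(layer.weight.dtype))
print(out.dtype)  # torch.float64
```

In the trace above, the same mismatch occurs the other way around: float32 activations hit fp16 (`c10::Half`) conv weights.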
Johanan528 commented 1 year ago

I haven't encountered a similar issue. Perhaps try resetting the accelerate config manually?

"Do you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: No"
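After answering the prompt, the saved config (the default location and key name below are assumptions based on accelerate's defaults) should contain something like:

```yaml
# Typically saved at ~/.cache/huggingface/accelerate/default_config.yaml
mixed_precision: 'no'
```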

Kyfafyd commented 1 year ago

Thanks for the response! So the training can only be run under fp32? Is fp16 applicable for saving memory and speeding things up? My GPU only has 24 GB of memory, so running under fp32 causes an OOM error.