crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch

`--mixed-precision` doesn't work with image transformer v2 #91

Closed: yoinked-h closed 6 months ago

yoinked-h commented 6 months ago

When trying to train with mixed precision (and NATTEN), the positional embedding gets cast to fp32 rather than bf16, which causes a dtype error later in the attention.forward call.
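For context, the mechanism can be reproduced in plain PyTorch: adding a float32 parameter to bf16 activations silently promotes the result to float32, so a later kernel that requires uniform input dtypes fails. This is a minimal sketch of the symptom, not the k-diffusion code:

```python
import torch

# A learned positional embedding created as a float32 Parameter stays
# float32; autocast only downcasts the outputs of selected ops.
pos_emb = torch.nn.Parameter(torch.randn(16, 64))

# Activations inside a bf16 mixed-precision region.
x = torch.randn(16, 64, dtype=torch.bfloat16)

# PyTorch type promotion: bf16 + fp32 -> fp32, so the "bf16" path of
# the model silently produces float32 tensors.
q = x + pos_emb
print(q.dtype)  # torch.float32

# A downstream kernel that requires all inputs to share one dtype
# (as the NATTEN ops do) then raises on the mismatch.
```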

crowsonkb commented 6 months ago

I just fixed a similar-sounding problem with the dtypes of the tensors being passed to natten2dav(), which only occurred with very recent versions of NATTEN (commit: https://github.com/crowsonkb/k-diffusion/commit/6ab5146d4a5ef63901326489f31f1d8e7dd36b48). Can you pull and check whether this fixes your problem?
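For anyone hitting this on an older checkout, the general defensive pattern is to cast to a common dtype right before the kernel call. A hedged sketch only; the `attend` helper below is hypothetical and is not the code in the linked commit:

```python
import torch

def attend(attn: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Illustrative guard: ensure both operands share one dtype before
    # calling a kernel (such as natten2dav) that does not promote
    # mixed dtypes itself.
    if attn.dtype != v.dtype:
        attn = attn.to(v.dtype)
    # Placeholder for the actual neighborhood-attention call; a plain
    # matmul stands in so the sketch stays self-contained.
    return attn @ v
```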

yoinked-h commented 6 months ago

This seems to have fixed it, ty!