Closed: yoinked-h closed this issue 6 months ago
I just fixed a similar-sounding problem with the dtypes of the tensors passed into natten2dav(), which only occurred with very recent versions of NATTEN (commit: https://github.com/crowsonkb/k-diffusion/commit/6ab5146d4a5ef63901326489f31f1d8e7dd36b48). Can you pull and check whether this fixes your problem?
this seems to have fixed it, ty!
When trying to train with mixed precision (and NATTEN), the positional embedding is cast to fp32 rather than bf16, causing a dtype mismatch error later in the attention.forward call.
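A minimal sketch of the kind of mismatch described above, assuming the fix is casting the positional bias to the query's dtype before the attention kernel sees it. The function name `attend_with_pos_bias` and the exact shapes are illustrative, not k-diffusion's or NATTEN's actual code:

```python
import torch

def attend_with_pos_bias(q, k, pos_bias):
    # Under autocast/mixed precision, q and k may be bf16 while a
    # positional bias created outside the autocast region stays fp32.
    # Kernels that require matching dtypes (as described in the issue)
    # will error on the mix, so cast the bias to match q first.
    pos_bias = pos_bias.to(q.dtype)  # the kind of fix applied upstream
    attn = q @ k.transpose(-2, -1) + pos_bias
    return attn.softmax(dim=-1)
```

Without the `.to(q.dtype)` cast, the fp32 bias either silently promotes the result or, in stricter kernels, raises a dtype error, which matches the failure mode reported here.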