Closed · jwooldridge234 closed this issue 10 months ago
Hi @jwooldridge234, I don't see obvious errors. Which version of Diffusers are you using? The provided updated pipeline has only been tested with Diffusers 0.21.4; newer versions may not work.
@YumengLi007 Yeah, I just checked, and I'm running diffusers 0.21.4. EDIT: Just tested running it on CPU and it finishes with a loss of -0.08085955679416656, so it's definitely an MPS backend issue.
Good to know. Thanks for digging into the issue.
Did a bit more digging and found the exact line where it fails (line 129 of the pipeline, in DivideBindAttnProcessor): attention_probs = attn.get_attention_scores(query, key, attention_mask)
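A minimal probe along these lines can help confirm where values first go non-finite (the helper below is illustrative, not part of the pipeline):

```python
import torch

def report_nonfinite(name, t):
    # Quick probe for NaN/Inf values; handy for bisecting where the
    # MPS path first produces bad numbers.
    bad = ~torch.isfinite(t)
    if bad.any():
        print(f"{name}: {int(bad.sum())} / {t.numel()} non-finite values")

# Dropped in right after the failing call in DivideBindAttnProcessor:
#   attention_probs = attn.get_attention_scores(query, key, attention_mask)
#   report_nonfinite("attention_probs", attention_probs)
```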
Looks like there's a bug in diffusers.models.attention_processor on MPS. I'll raise this with them directly and close this issue. Thanks for being so responsive!
Hi @jwooldridge234 , just saw this, not sure if switching to torch.float32 could fix the issue 😅
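For reference, a sketch of that suggestion (untested; the pipeline class name below is an assumption, so check pipeline_divide_and_bind_latest.py for the actual export):

```python
# Keep everything in float32 on MPS and force the attention-score math
# to upcast as well.
import torch
from pipeline_divide_and_bind_latest import DivideAndBindPipeline  # hypothetical class name

pipe = DivideAndBindPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float32,         # fp16 is a common source of MPS NaNs
).to("mps")

# In diffusers 0.21.x, Attention.get_attention_scores() upcasts query/key
# to float32 before the baddbmm/softmax when this flag is set:
for module in pipe.unet.modules():
    if hasattr(module, "upcast_attention"):
        module.upcast_attention = True
```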
Not certain if this is a Mac-specific issue, as I don't have a different system to test on. I'm running pipeline_divide_and_bind_latest.py, and after the first step of the (while target_indicator < target) loop in _perform_iterative_refinement_step it starts producing tensors with NaN values. Ultimately, the loss finishes at -inf.
I tried to do some debugging, but I can't pin down the exact point of failure. It's not the loss calculation itself (although it emits the warning listed below): the AttentionStore also returns NaN tensors after the first step, and that's what leads to the loss calculation failure.
Here's the code I'm using to call the pipeline:
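(A representative sketch of such a call, continuing from the loading example above; the prompt, token indices, and argument names are placeholders following the Attend-and-Excite-style API this pipeline builds on, not the original snippet:)

```python
torch.manual_seed(42)  # global seed; MPS Generator support varies by PyTorch version
image = pipe(
    prompt="a cat and a dog",
    token_indices=[2, 5],  # assumed Attend-and-Excite-style argument
    num_inference_steps=50,
).images[0]
image.save("out.png")
```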
Warning from the loss calculation: UserWarning: Using a target size (torch.Size([15, 16])) that is different to the input size (torch.Size([1, 16])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
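That warning is generic PyTorch broadcasting behavior and is easy to reproduce in isolation (mse_loss is used here purely as an example; the pipeline's actual loss may differ):

```python
import torch
import torch.nn.functional as F

inp = torch.zeros(1, 16)
tgt = torch.zeros(15, 16)
F.mse_loss(inp, tgt)                 # emits the UserWarning above
F.mse_loss(inp.expand_as(tgt), tgt)  # shapes aligned explicitly: no warning
```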
I'm running the latest PyTorch nightly (2.3.0.dev20240117), but it also failed on the latest stable release.
Please let me know if you see any errors in my implementation.
Many thanks!