Tried it, but got RuntimeError: The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (1, 1). #8
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([77, 77]), but should be (1, 1).
Did you run the code with the provided style image samples, or with pure text only? If neither works, the most likely cause is an incompatible library version. Try using the exact revision.
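To compare your environment against the expected revision, something like the following might help. It is only a sketch using the standard library: the package list is hypothetical (only `torch` is certain from the traceback), so replace it with whatever the project actually pins.

```python
from importlib import metadata

# Hypothetical package list: only "torch" is known from the traceback.
# Replace the rest with the packages the project actually pins.
PACKAGES = ["torch", "torchvision", "transformers"]


def installed_versions(packages):
    """Return a dict mapping each package name to its installed version,
    or to None when the package is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Posting that output here would make it easier to spot a version mismatch.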