ehdgus8077 opened this issue 11 months ago (status: Open)
cc. @harrydrippin
Thank you @ehdgus8077 for posting your issue. We have not validated support for the IP-Adapter. You can try the '--enable-saturate-infinity' compiler option, which can sometimes resolve NaN issues in model compilation. In the meantime, we will try to reproduce your issue and will update you when we have further information.
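For reference, a minimal sketch of passing that flag, assuming the UNet is compiled with torch_neuronx.trace (if your compilation script uses a different entry point, pass the option through its compiler-arguments parameter instead; `unet` and `example_inputs` below are placeholders for the model and sample inputs you already use):

```python
import torch_neuronx

# Sketch only: `unet` and `example_inputs` are placeholders for the model and
# the sample inputs already used in your compilation script.
compiled_unet = torch_neuronx.trace(
    unet,
    example_inputs,
    compiler_args=["--enable-saturate-infinity"],
)
```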
Problem
I am currently using the Stable Diffusion 1.5 model with the IP-Adapter module. IP-Adapter is a module designed to generate images that resemble a reference image; it adds a separate image encoder and its own cross-attention module.
However, when the IP-Adapter module is used, the output of the UNet in the Stable Diffusion model starts producing NaN values from the very first denoising step.
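A simple way to confirm this kind of failure is to check the UNet outputs for NaNs. The sketch below registers forward hooks on the eager (uncompiled) UNet and reports the first submodule whose output contains NaNs; `unet` is a placeholder for the loaded UNet module:

```python
import torch

def nan_hook(module, inputs, output):
    # Report the first submodule whose output contains NaNs.
    tensors = output if isinstance(output, (tuple, list)) else (output,)
    for t in tensors:
        if isinstance(t, torch.Tensor) and torch.isnan(t).any():
            print(f"NaN produced by {module.__class__.__name__}")
            break

# Placeholder: attach the hook to every submodule of the eager UNet under test.
handles = [m.register_forward_hook(nan_hook) for m in unet.modules()]
```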
Examination
I have thoroughly examined the issue and found the following:
Cases with no issues: generating without the IP-Adapter, or generating square images, works fine.
The problem arises only when the IP-Adapter is used together with rectangular (non-square) image dimensions, such as 768x512.
I have followed the model compilation code from this link and changed the precision to bfloat16 for all modules except the VAE.
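In code, the precision setup looks roughly like this before compilation (a sketch using the standard diffusers pipeline for SD 1.5; the actual script follows the linked compilation code):

```python
import torch
from diffusers import StableDiffusionPipeline

# Sketch of the precision setup: everything except the VAE is cast to
# bfloat16 before compilation; the VAE is kept in float32.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.unet.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
# pipe.vae stays in float32; the IP-Adapter image encoder is likewise cast to bfloat16.
```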
I tried to debug by checking almost every line of this file and found an incomprehensible fix: in IPAttnProcessor, simply multiplying ip_hidden_states by 1.0001 makes the model function normally.

Code
```python
import torch
import torch.nn as nn


class IPAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, scale=0.3, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states (text tokens) and ip_hidden_states (image tokens)
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # regular text cross-attention
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        # There is a NaN issue when ip_hidden_states is used without modifications.
        ip_hidden_states = ip_hidden_states * 1.0001

        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
```

Thank you for your help!