aws-neuron / aws-neuron-sdk

Powering AWS purpose-built machine learning chips. Blazing fast and cost effective, natively integrated into PyTorch and TensorFlow and integrated with your favorite AWS services
https://aws.amazon.com/machine-learning/neuron/

[inferentia2] Model produces `NaN` when using Stable Diffusion with IP-Adapter #804

Open ehdgus8077 opened 11 months ago

ehdgus8077 commented 11 months ago

Problem

I am currently using the Stable Diffusion 1.5 model with the IP-Adapter module. IP-Adapter is a module designed to generate images similar to reference images; it adds a separate image encoder and a cross-attention module.

However, when the IP-Adapter module is used, the output of the UNet module in the Stable Diffusion model contains NaN values from the very first step.
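A check along these lines can confirm at which step the NaN values first appear. This is a minimal sketch, assuming the per-step UNet outputs are available as PyTorch tensors; `first_nan_step` and the stand-in tensors are hypothetical and not part of the original pipeline.

```python
import torch


def first_nan_step(step_outputs):
    """Return the index of the first tensor that contains NaN, or None."""
    for step, out in enumerate(step_outputs):
        if torch.isnan(out).any().item():
            return step
    return None


# Stand-in tensors for per-step UNet outputs (hypothetical data).
outputs = [
    torch.zeros(1, 4, 8, 8),
    torch.full((1, 4, 8, 8), float("nan")),
]
print(first_nan_step(outputs))
```

In a real run, the helper would be fed the tensors returned by the compiled UNet at each denoising step.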

Examination

I have thoroughly examined the issue and found the following:

Cases with no issues:

The problem currently arises only when using the IP-Adapter and creating images with rectangular dimensions, such as 768x512.

I followed the model compilation code from this link and changed the precision to bfloat16 for all modules except the VAE.

I tried to debug this by checking almost every line of this file and found a fix I cannot explain: in the IPAttnProcessor, simply multiplying ip_hidden_states by 1.0001 makes the model function normally.

Code

```python
class IPAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, scale=0.3, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        # There is a NaN issue when ip_hidden_states is used without modifications.
        ip_hidden_states = ip_hidden_states * 1.0001
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
```
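One detail about the `* 1.0001` workaround may help narrow this down. bfloat16 keeps only 8 significant bits, so 1.0001 itself rounds to 1.0, and a bfloat16 value multiplied by 1.0001 rounds back to itself. The sketch below emulates bfloat16 rounding on the host with the standard library. `to_bfloat16` is a hypothetical helper, and it assumes the product is stored in bfloat16; how neuronx-cc actually evaluates the multiply is not known from this thread.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a Python float to the nearest bfloat16 value (round-to-nearest-even)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias so that truncating the low 16 bits rounds to nearest even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# The constant is not representable in bfloat16.
print(to_bfloat16(1.0001))  # 1.0

# Values that are already bfloat16 are unchanged by the multiply after rounding.
samples = [to_bfloat16(v) for v in (0.3, -1.5, 12.0, 0.007)]
print(all(to_bfloat16(x * 1.0001) == x for x in samples))  # True
```

If that assumption holds on device as well, the workaround may be changing how the graph is compiled rather than the values themselves, which could be worth mentioning when trying to reproduce the issue.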

Thank you for your help!

ehdgus8077 commented 11 months ago

cc. @harrydrippin

jeffhataws commented 11 months ago

Thank you @ehdgus8077 for posting your issue. We have not validated support for IP-Adapter. You can try the `--enable-saturate-infinity` compiler option, which is sometimes able to resolve NaN issues in model compilation. In the meantime, we will try to reproduce your issue and will update you when we have further information.
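For reference, the option can be passed at trace time roughly as follows. This is a sketch that assumes the UNet is compiled with `torch_neuronx.trace` and its `compiler_args` parameter; the helper names `neuron_compiler_args` and `trace_with_saturation` are hypothetical.

```python
def neuron_compiler_args(extra=None):
    """Build the compiler argument list, always including the saturation flag."""
    args = ["--enable-saturate-infinity"]
    if extra:
        args.extend(extra)
    return args


def trace_with_saturation(model, example_inputs, extra_args=None):
    """Compile `model` for Neuron with infinity saturation enabled."""
    # Imported lazily so this module can be loaded without the Neuron SDK installed.
    import torch_neuronx

    return torch_neuronx.trace(
        model,
        example_inputs,
        compiler_args=neuron_compiler_args(extra_args),
    )


print(neuron_compiler_args())
```

Any other compiler options already in use for the UNet can be passed through `extra_args` so that only the saturation flag changes between runs.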