segmind / segmoe

Why using negative prompt hidden states as gate weight? #19

Open IvanFei opened 9 months ago

IvanFei commented 9 months ago

hi,

ref: https://github.com/segmind/segmoe/blob/5fce95320f932aeb0991c9c0c31a3be72dbf7ce8/segmoe/main.py#L1300C13-L1300C26

  @torch.no_grad
  def get_hidden_states(self, model, positive, negative, average: bool = True):
      intermediate = {}
      self.cast_hook(model, intermediate)
      with torch.no_grad():
          _ = model(positive, negative_prompt=negative, num_inference_steps=25)
      hidden = {}
      for key in intermediate:
          hidden_states = intermediate[key][0][-1]  #### why using negative prompt as hidden states
          if average:
              # use average over sequence
              hidden_states = hidden_states.sum(dim=0) / hidden_states.shape[0]
          else:
              # take last value
              hidden_states = hidden_states[:-1]
          hidden[key] = hidden_states.to(self.device)
      del intermediate
      gc.collect()
      torch.cuda.empty_cache()
      return hidden

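
For reference, the two reduction branches in the quoted function can be sketched with plain Python lists standing in for a (seq_len, dim) tensor. This is only an illustration, not SegMoE code, and the helper names mean_over_sequence and last_token are hypothetical. Note that in both Python and PyTorch indexing, [-1] selects the last row, while [:-1] selects every row except the last.

```python
# Illustrative only: plain lists stand in for a (seq_len, dim) tensor.
# mean_over_sequence and last_token are hypothetical helper names.

def mean_over_sequence(hidden_states):
    """Average the token vectors, mirroring sum(dim=0) / shape[0]."""
    seq_len = len(hidden_states)
    dim = len(hidden_states[0])
    return [sum(row[d] for row in hidden_states) / seq_len for d in range(dim)]


def last_token(hidden_states):
    """Return only the final token vector (index -1)."""
    return hidden_states[-1]


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, 2 features

print(mean_over_sequence(tokens))  # [3.0, 4.0]
print(last_token(tokens))          # [5.0, 6.0]
print(tokens[:-1])                 # [[1.0, 2.0], [3.0, 4.0]] (all but the last)
```
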
Warlord-K commented 9 months ago

We take the negative prompt into account because many finetunes recommend specific negative prompts to use with them. The idea is that when the router encounters a similar prompt and negative prompt, it will route to that specific model's layer. That said, we still have to run ablation tests to see how much including these negative prompts affects the final SegMoE.
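
The routing idea described above can be sketched as a dot product between an incoming hidden vector and one stored vector per expert, followed by a softmax. This is a minimal illustration under stated assumptions: plain lists stand in for tensors, the numbers are made up, and gate_scores is a hypothetical helper rather than SegMoE's actual router.

```python
import math


def gate_scores(hidden, gate_weights):
    """Softmax over dot products between `hidden` and each expert's gate vector."""
    logits = [sum(h * w for h, w in zip(hidden, row)) for row in gate_weights]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# One gate vector per expert, e.g. the averaged hidden state recorded for
# that expert's representative prompt pair (assumed values).
gate_weights = [
    [1.0, 0.0],  # expert 0
    [0.0, 1.0],  # expert 1
]

scores = gate_scores([2.0, 0.5], gate_weights)
best = max(range(len(scores)), key=scores.__getitem__)
print(best)  # 0, because the input aligns more with expert 0's vector
```

Under this sketch, an input that resembles an expert's stored vector gets a larger score for that expert, which is the behaviour the reply describes for similar prompt and negative-prompt pairs.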

IvanFei commented 9 months ago

Thank you for the kind reply.

Why not use the hidden states of the positive prompt instead, e.g. hidden_states = intermediate[key][0][0]? With classifier-free guidance, the positive and negative prompts are combined into a single batch for inference.
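
To make the indexing question concrete, here is a small sketch of how a classifier-free guidance batch can be laid out. It assumes the ordering used in diffusers' Stable Diffusion pipelines, where the negative (unconditional) embeddings are concatenated before the positive ones. Whether intermediate[key][0] in SegMoE follows the same layout depends on what cast_hook captures, which this thread does not show. Strings stand in for tensors, and build_cfg_batch is a hypothetical helper.

```python
# Illustrative only: strings stand in for per-sample hidden-state tensors.
# Assumed layout: negative (unconditional) sample first, positive second.

def build_cfg_batch(negative, positive):
    """Combine the two conditioning branches into a single batch."""
    return [negative, positive]


batch = build_cfg_batch("hidden(negative prompt)", "hidden(positive prompt)")

print(batch[0])   # hidden(negative prompt)
print(batch[-1])  # hidden(positive prompt)
```

Under that assumed ordering, index -1 would select the positive branch and index 0 the negative one, so it is worth checking the actual batch order before drawing conclusions from the index alone.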

Here's another question I'd like to ask:

  1. I saw that SegMoE uses the hidden-state activations from the last step of the diffusion process. Have you tried averaging the hidden states over all diffusion steps, or anything else?
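
One way to average activations over every diffusion step, rather than keeping only the last one, is to update a running mean each time a step produces a new vector. The sketch below is an assumption about how this could be done, not something SegMoE is known to implement. RunningMean is a hypothetical class, and the per-step values are made up.

```python
class RunningMean:
    """Accumulate an element-wise mean of the vectors seen across steps."""

    def __init__(self):
        self.count = 0
        self.mean = None

    def update(self, vector):
        self.count += 1
        if self.mean is None:
            self.mean = list(vector)
        else:
            # Incremental mean: m += (x - m) / n
            self.mean = [m + (x - m) / self.count
                         for m, x in zip(self.mean, vector)]
        return self.mean


tracker = RunningMean()
# Assumed per-step activations for one layer (3 diffusion steps, 2 features).
for step_hidden in ([1.0, 4.0], [3.0, 8.0], [5.0, 0.0]):
    tracker.update(step_hidden)

print(tracker.mean)  # [3.0, 4.0]
```

Keeping a running mean avoids storing the activations of every step, which matters when many layers are hooked at once.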