def forward(self, x, condition):
    """Run one adaLN (adaptive LayerNorm) transformer block.

    Args:
        x: input tokens; assumed shape (batch, seq, dim) — the per-sample
            modulation vectors are broadcast over the seq axis via
            ``unsqueeze(1)``. TODO confirm against callers.
        condition: per-sample conditioning tensor; ``self.adaptive_norm_layer``
            must project it to ``6 * dim`` features, which are split into
            (shift, scale, gate) triples for the attention and MLP sub-blocks.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # Project the condition into six per-channel modulation vectors.
    (pre_attn_shift, pre_attn_scale, post_attn_scale,
     pre_mlp_shift, pre_mlp_scale, post_mlp_scale) = (
        self.adaptive_norm_layer(condition).chunk(6, dim=1))

    out = x

    # Attention sub-block: scale/shift the normalized input, then add the
    # gated residual.
    attn_norm_output = (self.att_norm(out) * (1 + pre_attn_scale.unsqueeze(1))
                        + pre_attn_shift.unsqueeze(1))
    out = out + post_attn_scale.unsqueeze(1) * self.attn_block(attn_norm_output)

    # MLP sub-block: same modulate-then-gate pattern.
    mlp_norm_output = (self.ff_norm(out) * (1 + pre_mlp_scale.unsqueeze(1))
                       + pre_mlp_shift.unsqueeze(1))
    # Bug fix: the original routed the MLP-normalized activations back through
    # self.attn_block; the adaLN/DiT block applies the MLP (feed-forward) here.
    out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output)
    return out
According to the paper and the video you made, the second-to-last line of `forward` should call the MLP block rather than the attention block:

    out = out + post_mlp_scale.unsqueeze(1) * self.mlp_block(mlp_norm_output)

Finally, I really appreciate your video and your work!