Open nbdyn opened 3 years ago
Hi @zzzxxxttt , I also noticed this. If I follow the equation without the `max_w * ()` factor, the loss stays flat, which looks like vanishing gradients. So is it a typo in the paper itself? What do you think?
The version with `max_w *` actually looks more reasonable to me: since we divide by the max, we naturally want to scale the weights back to their original range afterwards.
I also noticed this issue. I would prefer to remove the `max_w` term here.
Is there any conclusion?
```python
class weight_quantize_fn(nn.Module):
    def __init__(self, w_bit):
        super(weight_quantize_fn, self).__init__()
        assert w_bit <= 8 or w_bit == 32
        self.w_bit = w_bit
        self.uniform_q = uniform_quantize(k=w_bit)

    def forward(self, x):
        if self.w_bit == 32:
            weight_q = x
        elif self.w_bit == 1:
            E = torch.mean(torch.abs(x)).detach()
            weight_q = self.uniform_q(x / E) * E
        else:
            weight = torch.tanh(x)
            max_w = torch.max(torch.abs(weight)).detach()
            weight = weight / 2 / max_w + 0.5
            weight_q = max_w * (2 * self.uniform_q(weight) - 1)
        return weight_q
```
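To make the difference concrete, here is a minimal pure-Python sketch of the k-bit branch above (tanh squash, normalize by the max, uniform quantize, then optionally rescale). The helper names and the toy weight values are my own for illustration; only the formula comes from the code above. It shows that without the `max_w *` rescale the quantized weights land in `[-1, 1]` (as in the paper's equation), while with it they are scaled back to `[-max_w, max_w]`, i.e. the range of `tanh(w)`:

```python
import math

def uniform_quantize(x, k):
    # round a value in [0, 1] to the nearest of 2^k - 1 uniform levels
    n = 2 ** k - 1
    return round(x * n) / n

def dorefa_weight(w, k, all_w, rescale):
    # same steps as the repo's forward(): tanh-squash, then map to [0, 1]
    # using the max absolute squashed weight over the whole tensor
    max_w = max(abs(math.tanh(v)) for v in all_w)
    x = math.tanh(w) / (2 * max_w) + 0.5
    q = 2 * uniform_quantize(x, k) - 1          # paper's form: output in [-1, 1]
    return max_w * q if rescale else q          # repo additionally multiplies by max_w

weights = [-1.5, -0.2, 0.3, 2.0]                # toy example values
paper = [dorefa_weight(w, 2, weights, rescale=False) for w in weights]
repo  = [dorefa_weight(w, 2, weights, rescale=True)  for w in weights]
print(paper)   # extremes hit exactly ±1.0
print(repo)    # extremes hit ±max|tanh(w)| instead
```

So the `max_w *` factor does not change which quantization bin a weight falls into; it only restores the magnitude that the `/ 2 / max_w` normalization removed.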
**weight_q = max_w * (2 * self.uniform_q(weight) - 1)**
In utils/quant_dorefa.py, line 46, why does it multiply by `max_w`? It does not seem to be needed in the formula.