zzzxxxttt / pytorch_DoReFaNet

A pytorch implementation of DoReFa-Net
MIT License

In utils/quant_dorefa.py, Line 46, why does it multiply by max_w? #10

Open nbdyn opened 3 years ago

nbdyn commented 3 years ago

```python
class weight_quantize_fn(nn.Module):
  def __init__(self, w_bit):
    super(weight_quantize_fn, self).__init__()
    assert w_bit <= 8 or w_bit == 32
    self.w_bit = w_bit
    self.uniform_q = uniform_quantize(k=w_bit)

  def forward(self, x):
    if self.w_bit == 32:
      weight_q = x
    elif self.w_bit == 1:
      E = torch.mean(torch.abs(x)).detach()
      weight_q = self.uniform_q(x / E) * E
    else:
      weight = torch.tanh(x)
      max_w = torch.max(torch.abs(weight)).detach()
      weight = weight / 2 / max_w + 0.5
      # Line 46: the rescaling by max_w in question
      weight_q = max_w * (2 * self.uniform_q(weight) - 1)
    return weight_q
```

**weight_q = max_w * (2 * self.uniform_q(weight) - 1)**

In utils/quant_dorefa.py, Line 46, why does it multiply by max_w? It does not seem to be needed in the formula.

(image: the weight quantization formula from the DoReFa-Net paper)
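For reference, here is a minimal sketch of the two variants being compared. It assumes `uniform_q` behaves like the repo's `uniform_quantize(k=w_bit)` (rounding its [0, 1] input to k-bit levels with a straight-through estimator); the hypothetical `keep_max_w` flag toggles the disputed factor:

```python
import torch

def quantize_weight(x, uniform_q, keep_max_w=True):
    # Sketch only: uniform_q stands in for the repo's uniform_quantize op.
    weight = torch.tanh(x)
    max_w = torch.max(torch.abs(weight)).detach()
    weight = weight / (2 * max_w) + 0.5          # map tanh(x) into [0, 1]
    if keep_max_w:
        # repo version (quant_dorefa.py, Line 46): rescale back by max_w
        return max_w * (2 * uniform_q(weight) - 1)
    # literal paper formula: output always spans [-1, 1]
    return 2 * uniform_q(weight) - 1
```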

HaFred commented 3 years ago

Hi @zzzxxxttt, I also noticed this. If I follow the paper's equation without the max_w * (...) factor, the loss stays flat and it looks like the gradients vanish... So is it a typo in the paper itself? What do you think?

HaFred commented 3 years ago

Keeping the max_w * actually looks more reasonable to me: since we divide by the max beforehand, we naturally want the weight to be scaled back to its original range afterwards.
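A quick numeric check of that intuition (illustrative only; the `round(... * 255) / 255` below stands in for the repo's `uniform_quantize` with k = 8):

```python
import torch

torch.manual_seed(0)
x = 0.1 * torch.randn(1000)        # small weights, so tanh(x) has a small range
w = torch.tanh(x)
max_w = w.abs().max()

q = torch.round((w / (2 * max_w) + 0.5) * 255) / 255   # stand-in for uniform_q, k=8

with_max_w = max_w * (2 * q - 1)   # spans roughly [-max_w, max_w], like tanh(x)
without_max_w = 2 * q - 1          # always spans roughly [-1, 1]

print(w.abs().max().item(),
      with_max_w.abs().max().item(),
      without_max_w.abs().max().item())
```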

loveklmn commented 2 years ago

I also noticed this issue.

I would prefer to remove the max_w factor here.

Is there any conclusion?