Closed TobiasLee closed 4 years ago
There is an implicit float16 -> float32 conversion in the original code if the net weights are float16:
p.data = w + torch.Tensor(d).type(type(w))
which results in inaccurate loss computation. The following code avoids this problem, since type_as converts the direction to the exact w.dtype:

p.data = w + torch.Tensor(d).type_as(w)
Needs further testing.
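A minimal sketch of the promotion issue (the variable names w and d here are placeholders, not the repo's actual perturbation code): torch.Tensor(d) always produces a float32 tensor, so adding it to a float16 w promotes the result to float32, while type_as(w) keeps the sum in float16.

```python
import torch

# Hypothetical half-precision weight and a plain-Python direction vector
w = torch.zeros(3, dtype=torch.float16)
d = [0.1, 0.2, 0.3]

# torch.Tensor(d) is float32, so float16 + float32 promotes to float32
mixed = w + torch.Tensor(d)

# type_as(w) casts the direction to w's dtype first, so the sum stays float16
fixed = w + torch.Tensor(d).type_as(w)

print(mixed.dtype)  # torch.float32
print(fixed.dtype)  # torch.float16
```

This is why the perturbed weights silently end up in float32 with the original line: the cast target must be derived from w itself, not left to torch.Tensor's default dtype.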