Open githubpeichao opened 2 weeks ago
I found the forward function of the torch.nn._BatchNorm op:
class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        ...
        # Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        # Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        # passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        # used for normalization (i.e. in eval mode when buffers are not None).
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
This line, bn_training = (self.running_mean is None) and (self.running_var is None), may cause the problem above.
It seems that aten::instance_norm without the running_mean and running_var buffers can be regarded as running in training mode.
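To make the suspected behavior concrete, here is a minimal sketch of the decision quoted above, written as a plain Python function with no PyTorch dependency. The name uses_batch_stats is hypothetical and only mirrors the bn_training logic from the snippet; whether aten::instance_norm reaches the same path in ExecuTorch is the hypothesis of this thread, not something the sketch proves.

```python
from typing import Optional, Sequence


def uses_batch_stats(
    training: bool,
    running_mean: Optional[Sequence[float]],
    running_var: Optional[Sequence[float]],
) -> bool:
    """Mirror the bn_training decision quoted from _BatchNorm.forward.

    Hypothetical helper for illustration only; it is not a PyTorch API.
    """
    if training:
        return True
    # In eval mode, mini-batch statistics are still used when no buffers exist.
    return running_mean is None and running_var is None


# Eval mode with tracked buffers: the stored statistics are used.
print(uses_batch_stats(False, [0.0], [1.0]))  # False
# Eval mode without buffers: normalization behaves as in training mode.
print(uses_batch_stats(False, None, None))    # True
```

So calling model.eval() alone would not be enough to get inference-mode normalization if the op carries no running statistics.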
I always thought the problem was with the torch.nn.BatchNorm2d op, but now I think the problem is with aten::instance_norm.
Seems like a duplicate of https://github.com/pytorch/executorch/issues/4669
Yes, it seems like the same problem.
I have replaced aten::instance_norm with BatchNorm1d, and executor_runner now runs successfully.
But this is only a temporary solution.
🐛 Describe the bug
This happens when I run executor_runner with my pte model. It seems to fail on the torch.nn.BatchNorm2d op. Here is the trace:
I am sure that I called model.eval() before exporting the model. Here is my code for exporting the model:
Versions