wojtke opened 1 month ago
Thank you for reporting this issue!
I see that the error is caused by some assert assuming that `aten__native_batch_norm_legit_no_stats` is in training mode, even if it is not.
It's actually the other way around: the assert is asserting `training == false`, which is being violated, i.e. `training` is equal to `true`. Notice the `True` argument to `executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats` in the exported graph:

```
aten__native_batch_norm_legit_no_stats = executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_stats(aten_view_copy_default, None, None, True, 0.1, 1e-05);  aten_view_copy_default = None
```
The question is: why did `training` end up being equal to `True`, if you are just doing eval?
If you run this under `with torch.no_grad():`, does it still happen?
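For context, `torch.no_grad()` only disables autograd bookkeeping; it does not change a module's `training` flag, which is what the exported op's argument reflects. A quick sketch:

```python
import torch

norm = torch.nn.InstanceNorm2d(4).eval()  # sets norm.training = False

with torch.no_grad():
    # no_grad disables gradient tracking; it does not flip norm.training
    print(norm.training)  # False
    y = norm(torch.randn(1, 4, 8, 8))

print(y.requires_grad)  # False: produced under no_grad
```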
The exact same thing happens.
It seems that batch norms with `track_running_stats=False`, as well as instance norm layers, decompose to `aten__native_batch_norm_legit_no_stats`, which doesn't use running mean or variance parameters. Because of this, `training=True` is always set. This doesn't mean they are in training mode; it just means they calculate mean and variance on the fly, similar to LayerNorm. The naming might be confusing, but I guess that's just how it works. The decomposition linked here supports this point.
I think it would be best to remove the first assert at https://github.com/pytorch/executorch/blob/caadd81e65bf1240479250124814bc625a13b50c/kernels/portable/cpu/op_native_batch_norm.cpp#L173-L185. Right now, the asserts contradict each other.
@wojtke Oh, like LayerNorm, I see. This makes some sense now. I am the one who added those contradicting asserts. I implemented this op only for completeness' sake; I thought the op wasn't meant to be reached during eval, only during training, because of the assert here. Ok, will work on this soon. Thank you for catching and figuring this out!
🐛 Describe the bug
I found that the `torch.nn.InstanceNorm2d` layer, which by default does not track running mean and var, is failing at runtime, even though it goes through the lowering process with no errors. I see that the error is caused by some assert assuming that `aten__native_batch_norm_legit_no_stats` is in training mode, even if it is not. The same goes for `torch.nn.BatchNorm2d` with `track_running_stats=False`.

Minimal case to reproduce:
The exported program at edge dialect step:
I built the runner using:
Then when I run the lowered model:
Versions