@justinxzhao @w4nderlust Initial observation: all the failed tests call out a particular norm layer's weight/bias parameters not being updated with PyTorch nightly. The same norm layers appear to be updated when using a GA version of PyTorch.
I still need to dig in more to determine the root cause.
It looks like `torch.nn.LayerNorm` in nightly behaves differently. In the GA version, the output of this layer has non-zero values, albeit very small, on the order of 1e-6. In the nightly version the output is zero. As a result, I believe this leads to a zero gradient and no parameter updates.
Anyway, going further down the rabbit hole.
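One way to sanity-check the zero-gradient hypothesis is a minimal sketch like the following (the squared-sum loss is an arbitrary stand-in for a real training loss, not anything from the failing tests):

```python
import torch
from torch import nn

torch.manual_seed(42)
inputs = torch.randn(16, 1, dtype=torch.float32)
layer_norm = nn.LayerNorm(1)

outputs = layer_norm(inputs)
loss = outputs.pow(2).sum()  # arbitrary loss driven purely by the layer's output
loss.backward()

# LayerNorm's weight gradient is a sum of (upstream grad * normalized activation),
# so if the normalized activations are all zero it vanishes for any loss; with
# this particular loss the bias gradient is zero as well.
print(f"weight.grad: {layer_norm.weight.grad}")
print(f"bias.grad:   {layer_norm.bias.grad}")
```

On the affected nightly, where the layer's output is all zeros, both gradients print as zero tensors, so an optimizer step would leave the weight and bias unchanged.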
@justinxzhao @w4nderlust Here is an MWE of the zero output from `torch.nn.LayerNorm`:
```python
import torch
from torch import nn

if __name__ == "__main__":
    print(f"Torch version: {torch.__version__}")

    torch.manual_seed(42)
    # Single-feature input: shape (batch_size, 1)
    inputs = torch.randn(16, 1, dtype=torch.float32)

    layer_norm = nn.LayerNorm(1)
    outputs = layer_norm(inputs)
    print(f"layer norm output:\n{outputs}")
```
When run with torch-nightly, this is what I see:
```
Torch version: 2.2.0.dev20230908+cpu
layer norm output:
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], grad_fn=<NativeLayerNormBackward0>)
```
When I run the same program with torch 2.0.0, this is what I see:
```
Torch version: 2.0.0+cpu
layer norm output:
tensor([[-2.7311e-05],
        [-1.3996e-05],
        [-1.4542e-05],
        [-8.0787e-06],
        [ 5.3804e-06],
        [ 1.1374e-05],
        [ 1.2255e-08],
        [ 5.2535e-06],
        [-1.9136e-06],
        [ 1.2105e-05],
        [-3.8046e-06],
        [-1.3119e-05],
        [ 6.2030e-06],
        [-2.4129e-06],
        [ 4.1681e-06],
        [ 7.6920e-07]], grad_fn=<NativeLayerNormBackward0>)
```
I'm leaning toward there being an issue with `nn.LayerNorm` in the nightly version. If you concur, I'll open an issue with the torch project.
Nice investigation @jimthompson5802! The differing behavior for layer norm is suspicious. Opening an issue on the torch project to check if this is intended SGTM -- thanks!
Digging a little deeper... there appears to be another complicating factor. If the shape of the input tensor is (batch_size, 1), the layer norm outputs a zero tensor. On the other hand, if the second dimension is 2 or greater, the output is non-zero. Here are example outputs where the second dimension is 2 or greater (see the sketch after these outputs for why the single-column case degenerates):
```
Torch version: 2.2.0.dev20230909+cpu
layer norm output:
tensor([[ 0.9999, -0.9999],
        [ 1.0000, -1.0000],
        [ 1.0000, -1.0000],
        [ 1.0000, -1.0000],
        [-1.0000,  1.0000],
        [ 1.0000, -1.0000],
        [-0.9993,  0.9993],
        [-1.0000,  1.0000],
        [ 1.0000, -1.0000],
        [-1.0000,  1.0000],
        [-1.0000,  1.0000],
        [-1.0000,  1.0000],
        [-0.9392,  0.9392],
        [-1.0000,  1.0000],
        [-0.9997,  0.9997],
        [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>)
```
```
Torch version: 2.2.0.dev20230909+cpu
layer norm output:
tensor([[ 0.8716,  0.5928,  0.2209, -1.6853],
        [ 1.3440, -0.7473,  0.5552, -1.1519],
        [-0.4622,  1.6423, -0.1469, -1.0332],
        [-0.6401, -0.3735, -0.7050,  1.7186],
        [ 1.5784, -0.6330, -1.0476,  0.1023],
        [-1.6203,  0.4198,  0.1115,  1.0889],
        [ 0.4952,  0.5527, -1.7281,  0.6801],
        [-0.7452, -0.1393, -0.7894,  1.6739],
        [-1.0156, -0.5789, -0.0280,  1.6225],
        [ 0.9789, -0.5945,  0.9511, -1.3355],
        [-1.1176,  1.6082, -0.3940, -0.0965],
        [-0.7608,  1.6873, -0.7323, -0.1942],
        [-1.1710, -0.7253,  0.5579,  1.3384],
        [-0.3374,  1.7132, -0.7355, -0.6403],
        [-1.6057,  0.2972,  0.1657,  1.1428],
        [-0.4859,  1.1295, -1.3897,  0.7461]],
       grad_fn=<NativeLayerNormBackward0>)
```
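The single-column behavior is consistent with the textbook layer norm formula: with a normalized shape of 1, each element is normalized against its own statistics, so x - mean(x) is zero and the output collapses to the (zero-initialized) bias. A minimal sketch of that formula, purely for illustration (this is not the actual ATen kernel):

```python
import torch

def manual_layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Textbook layer norm over the last dimension, with the default
    # weight=1 and bias=0 of a freshly constructed nn.LayerNorm left out.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

x1 = torch.randn(16, 1)
x2 = torch.randn(16, 2)

print(manual_layer_norm(x1))  # all zeros: each element equals its own mean
print(manual_layer_norm(x2))  # roughly +/-1, matching the two-column log above
```

This also explains the two-column log: with exactly two features, the normalized values always land near ±1. The tiny 1e-5-scale values in the 2.0.0 run are presumably floating-point noise from a different computation path, whereas nightly produces an exact zero.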
To minimize computing resource requirements, the unit test only generates a single column of synthetic data. Maybe the short-term fix is to ensure at least two columns for the synthetic data, as sketched below. However, it may still make sense to open an issue with PyTorch for the corner case of a single input feature.
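A hypothetical sketch of that short-term workaround (the actual synthetic-data helpers in this repo differ; `make_synthetic_inputs` is an illustrative name, not an existing function):

```python
import torch

def make_synthetic_inputs(batch_size: int, num_features: int) -> torch.Tensor:
    # Clamp to at least two columns so nn.LayerNorm never sees the degenerate
    # (batch_size, 1) shape that yields an all-zero output on nightly.
    return torch.randn(batch_size, max(num_features, 2))
```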
In preparing to open an issue with Torch, I found this issue and its related PR. It was closed about two weeks ago and relates to computing the LayerNorm output. This might be the cause of our CI test issue.
I'll open an issue with Torch to see what they say about the corner case of a single column.
In the meantime, I'll update our unit test to generate at least two columns to see if that resolves our CI test issue.
Opened an issue with the PyTorch project regarding the all-zero output from `nn.LayerNorm` for a single-column batch: https://github.com/pytorch/pytorch/issues/108956
Please note this issue exists in nightly only, not in the PyTorch 2.1 release candidate. Test result:
```
Torch version: 2.1.0
layer norm output:
tensor([[-2.7311e-05],
        [-1.3996e-05],
        [-1.4542e-05],
        [-8.0787e-06],
        [ 5.3804e-06],
        [ 1.1374e-05],
        [ 1.2255e-08],
        [ 5.2535e-06],
        [-1.9136e-06],
        [ 1.2105e-05],
        [-3.8046e-06],
        [-1.3119e-05],
        [ 6.2030e-06],
        [-2.4129e-06],
        [ 4.1681e-06],
        [ 7.6920e-07]], grad_fn=<NativeLayerNormBackward0>)
```
Sample run
cc: @jimthompson5802