SciSharp / TensorFlow.NET

.NET Standard bindings for Google's TensorFlow for developing, training and deploying Machine Learning models in C# and F#.
https://scisharp.github.io/tensorflow-net-docs
Apache License 2.0

[BUG Report]: LayerNormalization error #1216

Closed: JustDooooIt closed this issue 7 months ago.

JustDooooIt commented 7 months ago

Description

Calling var gradient = g.gradient(y, tensor); throws an error. After removing the LayerNormalization layer, the program runs normally.

Reproduction Steps

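using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Common.Types;
using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

[TestClass]
public class LayerNormalizationTests // hypothetical class name; the original report showed only the method
{
    [TestMethod]
    public void SimVPTest()
    {
        var model = new TestModel1();
        using var g = tf.GradientTape();
        // Random NHWC input: batch 1, height 7, width 8, 13 channels
        var tensor = tf.random.normal((1, 7, 8, 13));
        // Track the input so the tape can differentiate with respect to it
        g.watch(tensor);
        var y = model.Apply(tensor);
        // Throws here when the model contains LayerNormalization
        var gradient = g.gradient(y, tensor);
    }
}

public class TestModel1 : Layer
{
    ILayer conv2d;
    ILayer norm;

    public TestModel1() : base(new())
    {
        // 16 filters, 3x3 kernel, stride 1, "same" padding
        conv2d = tf.keras.layers.Conv2D(16, 3, 1, "same");
        // Normalize over the last (channel) axis
        norm = tf.keras.layers.LayerNormalization(-1);
    }

    protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs optional_args = null)
    {
        var x = conv2d.Apply(inputs);
        x = norm.Apply(x);
        return x;
    }
}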

Known Workarounds

No response

Configuration and Other Information

No response

JustDooooIt commented 7 months ago

Hmm. Setting epsilon to 1e-6 avoids the error.
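For context, layer normalization over the last axis (LayerNormalization(-1), here the 16 channels produced by the Conv2D) normalizes each position's feature vector x of size H as shown below. The epsilon in the comment above is the ε added to the variance in the denominator, so the workaround changes only this stabilizing constant:

```latex
\mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad
\sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2, \qquad
y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i
```

This thread does not establish why a smaller ε avoids the error. One possible explanation, assuming TensorFlow.NET mirrors the Python Keras implementation: Keras LayerNormalization only takes the fused batch-norm path when ε is at least 1.001e-5 (its default is 1e-3). Setting ε = 1e-6 would therefore force the non-fused path, whose gradient may be supported where the fused one is not. In the repro this would look something like tf.keras.layers.LayerNormalization(-1, epsilon: 1e-6f). The parameter name epsilon is assumed here, not confirmed against the TensorFlow.NET API.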