apache / tvm

Open deep learning compiler stack for CPU, GPU and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][DNNL][BYOC] Transform Pass From float32 to bfloat16 failed in some patterns #12763

Closed · billishyahao closed this issue 1 year ago

billishyahao commented 2 years ago

When we invoke the pass

relay_mod = relay.transform.ToMixedPrecision('bfloat16')(relay_mod)

we expect a well-transformed model whose operators run with automatic mixed precision. However, we see the following error messages:

The Relay type checker is unable to show the following types match.
In particular `Tensor[(64), float32]` does not match `Tensor[(64), bfloat16]`
The Relay type checker is unable to show the following types match.
In particular `Tensor[(64), float32]` does not match `Tensor[(64), bfloat16]`

After printing out the Relay graph, we can see that this error occurs because nn.layer_norm belongs to DEFAULT_NEVER_LIST, yet no cast operators are inserted to promote its inputs from bfloat16 back to float32.
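To illustrate the invariant the pass is supposed to maintain, here is a minimal sketch in plain Python (not TVM's actual implementation; the names `NEVER_LIST` and `to_mixed_precision` are hypothetical stand-ins for DEFAULT_NEVER_LIST and ToMixedPrecision): ops on the never-convert list must keep float32 inputs, so when their producer has been converted to bfloat16, a cast back to float32 has to be inserted in front of them.

```python
# Hypothetical sketch of the cast-insertion invariant; a graph is modeled
# as a list of (op_name, dtype) pairs in topological order.
NEVER_LIST = {"nn.layer_norm"}  # conceptually mirrors DEFAULT_NEVER_LIST

def to_mixed_precision(ops):
    """Convert eligible ops to bfloat16, inserting a 'cast' node whenever
    a never-list op would otherwise receive a bfloat16 input."""
    out = []
    cur_dtype = "float32"
    for name, _ in ops:
        if name in NEVER_LIST:
            if cur_dtype == "bfloat16":
                # The step missing in the buggy patterns: promote the
                # input back to float32 before the never-list op.
                out.append(("cast", "float32"))
            out.append((name, "float32"))
            cur_dtype = "float32"
        else:
            out.append((name, "bfloat16"))
            cur_dtype = "bfloat16"
    return out

graph = [("nn.conv2d", "float32"), ("transpose", "float32"),
         ("nn.layer_norm", "float32")]
converted = to_mixed_precision(graph)
```

In the converted list, nn.layer_norm is preceded by an explicit cast back to float32; the bug report describes patterns where the real pass fails to insert that cast, so the type checker sees a bfloat16 tensor flowing into a float32 operator.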

Here is a minimal script that reproduces the issue:

import numpy as np
import tvm
from tvm import relay
# run_and_verify_func comes from the DNNL BYOC test utilities
# (tests/python/contrib/test_dnnl.py)

def test_layernorm(run_module, dtype="float32"):
    x = relay.var("x", shape=(4, 3, 224, 224), dtype="float32")
    k_shape = (64, 3, 4, 4)
    kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype))
    out = relay.nn.conv2d(
        x,
        kernel,
        channels=64,
        kernel_size=k_shape[2:4], 
        strides=[4, 4],
        padding=[0, 0, 0, 0]
    )
    # bias = relay.var("bias", shape=(64), dtype=dtype)
    bias = relay.const(np.random.randint(0, 1, (64,)).astype(dtype))
    out = relay.nn.bias_add(out, bias)

    out = relay.reshape(out, newshape=[4, 64, -1])

    out = relay.transpose(out, axes=[0, 2, 1])

    beta = relay.const(np.zeros((64)).astype("float32"))
    gamma = relay.const(np.ones((64)).astype("float32"))
    out = relay.nn.layer_norm(out, gamma=gamma, beta=beta)

    dic = {"x": (4, 3, 224, 224), "kernel": k_shape, "bias": (4, 64, 56, 56)}
    param_lst = ["kernel", "bias"]
    out = tvm.IRModule.from_expr(out)
    config = out, dic, param_lst
    run_and_verify_func(config, run_module=run_module, dtype=dtype)

I posted the code to my personal repo: https://github.com/apache/tvm/commit/fa1fb3ba1854c8c2c6fa813f95bb2454ff2f249c#diff-8b8dff568d8ee66e899bdc2b00d56db2a75e050c42b771fe38df17c4ca1ccdfb

yangulei commented 2 years ago

This bug will be fixed by https://github.com/apache/tvm/pull/12787; please help review the PR, thanks.