tenstorrent / tt-tvm

TVM for Tenstorrent ASICs
Apache License 2.0

Support triu function for `tvm.relay.expr.Call` Inputs #2

Closed. JushBJJ closed this 1 month ago.

JushBJJ commented 5 months ago

This PR addresses https://github.com/tenstorrent/tt-tvm/issues/3.

Proper support for the upper-triangular operation (triu) when the inputs are CallNodes would bring us one step closer to successfully implementing Qwen 1.5 (0.5B) (see https://github.com/tenstorrent/tt-buda-demos/issues/20).

Explanation

def triu(self, inputs, input_types):
    x = inputs[0]
    x_shape = _infer_shape(x)

+    # Symbolic (CallNode) inputs: delegate to the generic trilu converter
+    if isinstance(inputs[0], tvm.relay.expr.Call):
+        return self.trilu(inputs, input_types, mode="triu")

    mask = np.triu(np.ones(x_shape), inputs[1]).astype(bool)
    mask = tvm.nd.array(mask)
    mask = tvm.relay.Constant(mask)

    zeros = np.zeros(x_shape).astype(_convert_tvm_to_np_dtype(input_types[0]))
    zeros = tvm.nd.array(zeros)
    zeros = tvm.relay.Constant(zeros)

    return _op.where(mask, x, zeros)
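To make the fallback path concrete, here is a NumPy-only sketch of what the mask-based branch computes (`triu_via_mask` is a hypothetical name, not part of tt-tvm). It also suggests why this branch depends on a concrete input shape: the mask and zeros are built with NumPy at conversion time from `_infer_shape(x)` and then embedded as relay Constants.

```python
import numpy as np


def triu_via_mask(x: np.ndarray, k: int = 0) -> np.ndarray:
    """Hypothetical NumPy mirror of the constant-mask path in triu() above.

    The mask and the zeros tensor are materialized from x.shape, just as the
    converter builds them from _infer_shape(x) before wrapping them in
    relay Constants.
    """
    # Boolean mask: True on and above the k-th diagonal (builtin bool,
    # since the np.bool alias was removed in NumPy 1.24).
    mask = np.triu(np.ones(x.shape), k).astype(bool)
    # Zeros with the input's dtype, mirroring _convert_tvm_to_np_dtype.
    zeros = np.zeros(x.shape, dtype=x.dtype)
    # Elementwise select, like _op.where(mask, x, zeros).
    return np.where(mask, x, zeros)


x = np.arange(1, 10, dtype=np.float32).reshape(3, 3)
print(triu_via_mask(x))       # values [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
print(triu_via_mask(x, k=1))  # values [[0, 2, 3], [0, 0, 6], [0, 0, 0]]
```

The result matches `np.triu(x, k)`; the point of the sketch is only to show that this branch is a masked select over a statically shaped tensor.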

When compiling Qwen 1.5 (0.5B) (https://github.com/tenstorrent/tt-buda-demos/pull/37), one of the ops encountered is aten::triu, whose inputs contain nested op calls (CallNodes):

>> inputs[0]
CallNode(Op(add), [CallNode(Op(subtract), [CallNode(Op(subtract), [CallNode(Op(add), [Constant(256), Constant(0)], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(256)], (nullptr), [])], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(32768)], (nullptr), [])], (nullptr), []), Constant(1)], (nullptr), [])

>> type(inputs[0])
<tvm.relay.expr.Call>
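Every leaf in the printed expression is a `Constant`, so, assuming those constants are scalars, the whole tree is plain integer arithmetic: ((256 + 0) - 1*256) - 1*32768 + 1 = -32767. The sketch below re-creates the tree with Python tuples (a hypothetical representation, not TVM objects) and evaluates it recursively.

```python
from typing import Union

# Hypothetical stand-in for the printed Relay tree: ("op", lhs, rhs) tuples
# with plain ints as leaves. These are NOT tvm.relay objects.
Expr = Union[int, tuple]

OPS = {
    "add": lambda a, b: a + b,
    "subtract": lambda a, b: a - b,
    "multiply": lambda a, b: a * b,
}


def evaluate(expr: Expr) -> int:
    """Recursively evaluate a nested (op, lhs, rhs) tree of integer constants."""
    if isinstance(expr, int):
        return expr
    op, lhs, rhs = expr
    return OPS[op](evaluate(lhs), evaluate(rhs))


# Mirrors inputs[0] above:
# add(subtract(subtract(add(256, 0), multiply(1, 256)), multiply(1, 32768)), 1)
tree = (
    "add",
    ("subtract",
        ("subtract", ("add", 256, 0), ("multiply", 1, 256)),
        ("multiply", 1, 32768)),
    1,
)
print(evaluate(tree))  # -32767
```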

self.trilu appears to handle these inputs successfully with mode="triu", performing the upper-triangular operation.

Issues

1. Op(trilu) instead of Op(triu)

After calling self.trilu(inputs, input_types, mode="triu"), the resulting expression is:

CallNode(Op(trilu), [CallNode(Op(cast), [CallNode(Op(ones_like), [CallNode(Op(add),...

Is this a genuine issue, or can it be ignored for now?
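For context on this question: as far as we can tell from upstream TVM's Relay API, there is no separate triu op. Both triangular variants are expressed with a single `tvm.relay.trilu(data, k, upper=True)` op, where the `upper` flag selects the upper (True) or lower (False) triangle. If tt-tvm follows upstream here, `Op(trilu)` would simply be how triu is spelled in Relay. The NumPy sketch below models those semantics; `trilu_reference` is a hypothetical helper, not a TVM function.

```python
import numpy as np


def trilu_reference(data: np.ndarray, k: int = 0, upper: bool = True) -> np.ndarray:
    """Hypothetical NumPy model of one trilu op covering both triu and tril.

    Applies to the last two axes, so batches of matrices work as well.
    upper=True keeps elements with col - row >= k (like np.triu);
    upper=False keeps elements with col - row <= k (like np.tril).
    """
    rows = np.arange(data.shape[-2])[:, None]
    cols = np.arange(data.shape[-1])[None, :]
    offset = cols - rows  # diagonal index of each (row, col) position
    mask = offset >= k if upper else offset <= k
    # The 2-D mask broadcasts over any leading batch dimensions.
    return np.where(mask, data, np.zeros_like(data))


batch = np.arange(2 * 3 * 4).reshape(2, 3, 4)
print(trilu_reference(batch, k=0, upper=True))    # same as np.triu(batch)
print(trilu_reference(batch, k=-1, upper=False))  # same as np.tril(batch, -1)
```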

2. NaN tensor values for Grayskull e75

When @marty1885 tested this on his e75, he ran into an error where the tensor values were NaN. Oddly, @JonathanALevine's e150 compiled and ran it successfully, only running into other errors later.

*This is a draft for now, since it is only a workaround and not yet a proper fix.

JushBJJ commented 1 month ago

Closed, as this is no longer an issue for Qwen specifically.