tenstorrent / tt-tvm

TVM for Tenstorrent ASICs
Apache License 2.0

Refactor Where operator to avoid generation of NaN tensors #15

Closed by kamalrajkannan78 1 month ago

kamalrajkannan78 commented 1 month ago

In PyBuda, the where op is decomposed as follows:

elif type == "where":

    condition = inputs[0]
    x = inputs[1]
    y = inputs[2]
    one = dc.tensor(torch.ones((1,)))
    not_condition = dc.op("subtract", [one, condition])

    t0 = dc.op("multiply", [condition, x])
    t1 = dc.op("multiply", [not_condition, y])

    add = dc.op("add", [t0, t1])
    dc.fuse(add)
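To see how this decomposition behaves numerically, here is a minimal sketch of the same arithmetic in plain PyTorch. The function name `arithmetic_where` and the sample tensors are hypothetical and not part of PyBuda. The sketch assumes float32 on CPU, where the overflow shows up in the multiplications; the exact step that overflows may differ on other dtypes or hardware.

```python
import torch

def arithmetic_where(condition, x, y):
    """Hypothetical stand-in for the decomposition: cond * x + (1 - cond) * y."""
    one = torch.ones((1,))
    not_condition = one - condition
    return condition * x + not_condition * y

x = torch.tensor([2.0, 2.0])
y = torch.tensor([5.0, 5.0])

# With a clean 0/1 mask the arithmetic form agrees with torch.where.
clean_mask = torch.tensor([1.0, 0.0])
clean_out = arithmetic_where(clean_mask, x, y)
print(clean_out, torch.where(clean_mask.bool(), x, y))

# With a mask entry at the float32 minimum, the products overflow to -inf
# and +inf, and their sum is NaN.
fmin = torch.finfo(torch.float32).min
extreme_mask = torch.tensor([fmin, 0.0])
extreme_out = arithmetic_where(extreme_mask, x, y)
print(extreme_out)  # first element is NaN, second element is 5
```

The arithmetic form only matches a true select when every mask entry is exactly 0 or 1, which is why any other mask value is a problem for this decomposition.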

In some cases, the mask tensor (inputs[0]) may contain values other than 0s and 1s (example: -3.4028234663852886e+38 -> torch.finfo(torch.float32).min -> lowest representable value for a float32). Arithmetic on such mask values, even subtracting a very small value, can overflow to ±infinity, which then leads to a NaN tensor. To avoid that,

Fill values are currently replaced with ±1e4 without checking whether they are actually inf/-inf. To avoid that,
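As an illustration of the kind of check described here, the following is a minimal sketch that assumes the fill value is available as a plain Python float. The helper name `replace_inf_fill` and its `limit` parameter are hypothetical and do not come from tt-tvm; only the ±1e4 stand-in value is taken from the text above.

```python
import math

# Finite stand-in for infinity, as mentioned in the issue text.
FINITE_LIMIT = 1e4

def replace_inf_fill(value: float, limit: float = FINITE_LIMIT) -> float:
    """Return +/-limit only when value is +/-inf; leave other values untouched."""
    if math.isinf(value):
        # copysign keeps the sign of the original infinity.
        return math.copysign(limit, value)
    return value

print(replace_inf_fill(float("-inf")))  # -10000.0
print(replace_inf_fill(float("inf")))   # 10000.0
print(replace_inf_fill(-3.0))           # -3.0
```

Guarding on `math.isinf` means finite fill values pass through unchanged, so only true infinities are clamped.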