elif type == "where":
    condition = inputs[0]
    x = inputs[1]
    y = inputs[2]
    one = dc.tensor(torch.ones((1,)))
    not_condition = dc.op("subtract", [one, condition])
    t0 = dc.op("multiply", [condition, x])
    t1 = dc.op("multiply", [not_condition, y])
    add = dc.op("add", [t0, t1])
    dc.fuse(add)
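The decomposition above computes condition * x + (1 - condition) * y. A minimal sketch in plain PyTorch (not the pybuda decompose API; the name where_decomposed is hypothetical) illustrates why this only matches torch.where when the condition is strictly 0 or 1:

```python
import torch

def where_decomposed(condition, x, y):
    # Arithmetic form of where: condition * x + (1 - condition) * y.
    one = torch.ones((1,))
    not_condition = one - condition
    return condition * x + not_condition * y

condition = torch.tensor([1.0, 0.0, 1.0, 0.0])
x = torch.tensor([10.0, 20.0, 30.0, 40.0])
y = torch.tensor([-1.0, -2.0, -3.0, -4.0])

# With a binary mask the result agrees with torch.where.
result = where_decomposed(condition, x, y)
expected = torch.where(condition.bool(), x, y)

# With a non-binary mask value (2.0) the result is a blend, not a selection:
# 2 * 10 + (1 - 2) * (-1) = 21 instead of 10.
blended = where_decomposed(
    torch.tensor([2.0]), torch.tensor([10.0]), torch.tensor([-1.0])
)
```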
In some cases the mask tensor (inputs[0]) contains values other than 0 and 1, for example 3.4028234663852886e+38, which is torch.finfo(torch.float32).max, the largest finite float32 value (torch.finfo(torch.float32).min is its negative). Arithmetic on mask values of this magnitude, such as the subtraction that builds not_condition, can generate -infinity values, which then lead to a NaN tensor. To avoid that,
preprocessing steps were added to convert the condition (mask) into a binary tensor:
Cast the condition to float32.
Take the absolute value and clip it to the range [0, 1.0].
Convert the condition to binary (0 or 1) by comparing it to 0.
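The three preprocessing steps can be sketched in plain PyTorch as follows. This is an illustration under assumptions, not the pybuda implementation: the function name binarize_condition is hypothetical, and the real change would express the same steps through the decompose context rather than eager tensor ops.

```python
import torch

def binarize_condition(condition):
    # Step 1: cast the condition to float32.
    cond = condition.to(torch.float32)
    # Step 2: take the absolute value and clip it to [0, 1].
    cond = torch.clip(torch.abs(cond), 0.0, 1.0)
    # Step 3: compare with 0 so every non-zero entry becomes exactly 1.
    return (cond != 0).to(torch.float32)

fmax = torch.finfo(torch.float32).max
mask = torch.tensor([0.0, 1.0, fmax, -fmax, 0.5])
binary = binarize_condition(mask)
```

After this step the mask holds only 0.0 and 1.0, so 1 - mask stays within [0, 1] and the multiply/add decomposition behaves as a selection.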
Fill values were being replaced with ±1e4 without checking whether they were actually inf/-inf. To avoid that,
the handling of inf/-inf fill values was tightened:
A check was added for whether the value is infinity or -infinity. Only when it is does the replacement with ±1e4 happen, with the sign taken from the original value.
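A minimal sketch of the fill-value check, assuming a scalar Python float as input. The helper name sanitize_fill_value and the limit parameter are hypothetical and only illustrate the rule described above:

```python
import math

def sanitize_fill_value(value, limit=1e4):
    # Only infinite values are replaced; the sign of the original is kept.
    if math.isinf(value):
        return math.copysign(limit, value)
    # Finite values, however large, pass through unchanged.
    return value

pos = sanitize_fill_value(float("inf"))
neg = sanitize_fill_value(float("-inf"))
finite = sanitize_fill_value(-5e5)
```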
The snippet at the top of this description shows how the where op is decomposed in pybuda.