In the current example from SD Clip, there are two key kernels: add_kernel and mul_add_kernel. Once these kernels finish, their results are fed into an MLIR kernel, which I believe is a convolution.
I am implementing scalar_mul_add in simplify_algebra. It specifically handles the case where the multiplication is by a scalar and its result feeds a convolution. When both conditions hold, the expression is computed directly as a * (x + b) instead of being expanded into a * x + a * b or otherwise rewritten.
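A minimal sketch of the matching condition, written against a hypothetical toy IR in Python rather than the real MIGraphX matcher API. The names Node, is_scalar_literal and match_scalar_mul_add are illustrative assumptions, and "scalar" is assumed to mean a literal with exactly one element. The function starts at a convolution and checks whether its input is a * (x + b) with a scalar a, accepting the scalar on either side of the commutative mul.

```python
from dataclasses import dataclass, field
from math import prod
from typing import List, Optional, Tuple


@dataclass(eq=False)
class Node:
    """Hypothetical toy IR node; not the MIGraphX instruction type."""
    op: str                                      # "param", "literal", "add", "mul", "convolution"
    inputs: List["Node"] = field(default_factory=list)
    shape: Tuple[int, ...] = ()                  # only meaningful for literals in this sketch


def is_scalar_literal(n: Node) -> bool:
    # Assumption: a literal with exactly one element counts as a scalar.
    return n.op == "literal" and prod(n.shape) == 1


def match_scalar_mul_add(conv: Node) -> Optional[Tuple[Node, Node]]:
    """Return (a, add) if conv's first input is a * (x + b) with scalar a, else None."""
    if conv.op != "convolution" or not conv.inputs:
        return None
    mul = conv.inputs[0]
    if mul.op != "mul" or len(mul.inputs) != 2:
        return None
    lhs, rhs = mul.inputs
    # mul is commutative, so accept the scalar on either side.
    for a, other in ((lhs, rhs), (rhs, lhs)):
        if is_scalar_literal(a) and other.op == "add":
            return a, other
    return None


x = Node("param")
b = Node("literal", shape=(1, 3, 1, 1))
a = Node("literal", shape=())
add = Node("add", [x, b])
conv = Node("convolution", [Node("mul", [a, add]), Node("literal", shape=(8, 3, 3, 3))])
print(match_scalar_mul_add(conv) is not None)  # True
```

When the match succeeds, the graph is left as a * (x + b), so the convolution input still comes from a single add followed by a single scalar multiply.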
mul_add also needs an exception: if the multiplication is by a scalar, mul_add should not match, so no rewrite is performed.
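A minimal sketch of the proposed exception, again on a hypothetical toy IR in Python and not the actual simplify_algebra code. Node, is_scalar_literal and rewrite_mul_add are illustrative names. The toy rewrite assumes both a and b are literals and turns (x + b) * a into x * a + b * a, but returns the node unchanged when a is a scalar literal.

```python
from dataclasses import dataclass, field
from math import prod
from typing import List, Tuple


@dataclass(eq=False)
class Node:
    """Hypothetical toy IR node; not the MIGraphX instruction type."""
    op: str
    inputs: List["Node"] = field(default_factory=list)
    shape: Tuple[int, ...] = ()   # only meaningful for literals in this sketch


def is_scalar_literal(n: Node) -> bool:
    # Assumption: a literal with exactly one element counts as a scalar.
    return n.op == "literal" and prod(n.shape) == 1


def rewrite_mul_add(node: Node) -> Node:
    """Toy mul_add: (x + b) * a -> x * a + b * a for literal a and b,
    except when a is a scalar literal (the proposed exception)."""
    if node.op != "mul" or len(node.inputs) != 2:
        return node
    lhs, rhs = node.inputs
    if lhs.op == "add" and rhs.op == "literal":
        add, a = lhs, rhs
    elif rhs.op == "add" and lhs.op == "literal":
        add, a = rhs, lhs
    else:
        return node
    if len(add.inputs) != 2 or add.inputs[1].op != "literal":
        return node
    x, b = add.inputs
    if is_scalar_literal(a):
        return node  # scalar multiplier: do not match, keep (x + b) * a as-is
    return Node("add", [Node("mul", [x, a]), Node("mul", [b, a])])


x = Node("param")
b = Node("literal", shape=(1, 3, 1, 1))
vec = rewrite_mul_add(Node("mul", [Node("add", [x, b]), Node("literal", shape=(1, 3, 1, 1))]))
sc_in = Node("mul", [Node("add", [x, b]), Node("literal", shape=())])
print(vec.op)                           # add
print(rewrite_mul_add(sc_in) is sc_in)  # True
```

Returning the original node for the scalar case mirrors "should not match": the expression is simply left for other passes, such as the scalar_mul_add handling, to deal with.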
Draft PR