ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License
185 stars 85 forks source link

Accuracy issue with fuse_pointwise and GridSample #2923

Closed gyulaz-htec closed 5 months ago

gyulaz-htec commented 7 months ago

The issue is related to the GridSample operator (PR). When the interpolation mode is linear, the following operators are fused: equal, logical_and, where, mul and add. The fused kernel produces a GPU result that differs from the REF result.

The branch with the issue: https://github.com/ROCm/AMDMIGraphX/tree/grid_sample
The code path for GridSample linear interpolation: https://github.com/ROCm/AMDMIGraphX/pull/2909/files#diff-edf0354cc6d2ed3d8812944ab5d1d60ef374ff2ec5c5529991de6fdc7c16b83bR271-R328
Steps to reproduce:

Logs for the above mentioned steps: gridsample_no_fusion.log gridsample_fusion.log

pfultz2 commented 5 months ago

Can you paste the output of migraphx-driver read ./test/onnx/gridsample_simple_test.onnx --py here? This way we can debug this without needing GridSample branch.

Very likely we need to add a logical_and implementation here. migraphx::abs is applied to those operators because the vector types return -1, 0, 1 instead of just 0, 1.

Some things to try to debug it further:

1) Run with MIGRAPHX_GPU_DEBUG=1.
2) Disable vectorization (this would require changing the source code here).

gyulaz-htec commented 5 months ago

@pfultz2 the output of migraphx-driver read ./test/onnx/gridsample_simple_test.onnx --py:

# MIGraphX program for gridsample_simple_test.onnx, as printed by
# `migraphx-driver read ./test/onnx/gridsample_simple_test.onnx --py`.
# NOTE(review): this dump uses `migraphx` without importing it; add
# `import migraphx` before running it as a script.
# NOTE(review): literals are created with migraphx.generate_argument(shape, seed);
# presumably these are generated placeholder values rather than the constants
# stored in the ONNX file -- confirm before comparing numerical results.
p = migraphx.program()
m = p.get_main_module()
# Literals x_0..x_10. x_1 is int64 with lens [1,3] and is used as gathernd
# indices further down; all other literals are float.
x_0 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1,2]), 0))
x_1 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="int64_type", lens=[1,3]), 1))
x_2 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 2))
x_3 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 3))
x_4 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 4))
x_5 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 5))
x_6 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 6))
x_7 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 7))
x_8 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 8))
x_9 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 9))
x_10 = m.add_literal(migraphx.generate_argument(migraphx.shape(type="float_type", lens=[1]), 10))
# Graph inputs: the sampling grid [1,1,1,2] and the input image x [1,1,2,2].
p_grid = m.add_parameter("grid",migraphx.shape(type="float_type", lens=[1,1,1,2]))
p_x = m.add_parameter("x",migraphx.shape(type="float_type", lens=[1,1,2,2]))
# Split the last axis of grid into its two coordinates and drop that axis.
x_13 = m.add_instruction(migraphx.op("slice", axes=[3], starts=[0], ends=[1]), [p_grid])
x_14 = m.add_instruction(migraphx.op("slice", axes=[3], starts=[1], ends=[2]), [p_grid])
x_15 = m.add_instruction(migraphx.op("squeeze", axes=[3]), [x_13])
x_16 = m.add_instruction(migraphx.op("squeeze", axes=[3]), [x_14])
# Map each coordinate: (coord + x_9) * scale + x_8, with scale x_3 for the first
# coordinate and x_2 for the second.
x_17 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_18 = m.add_instruction(migraphx.op("add"), [x_15, x_17])
x_19 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_3])
x_20 = m.add_instruction(migraphx.op("mul"), [x_18, x_19])
x_21 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_8])
x_22 = m.add_instruction(migraphx.op("add"), [x_20, x_21])
x_23 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_24 = m.add_instruction(migraphx.op("add"), [x_16, x_23])
x_25 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_2])
x_26 = m.add_instruction(migraphx.op("mul"), [x_24, x_25])
x_27 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_8])
x_28 = m.add_instruction(migraphx.op("add"), [x_26, x_27])
# Neighbouring integer positions: floor(coord) and floor(coord) + x_9.
x_29 = m.add_instruction(migraphx.op("floor"), [x_22])
x_30 = m.add_instruction(migraphx.op("floor"), [x_28])
x_31 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_32 = m.add_instruction(migraphx.op("add"), [x_29, x_31])
x_33 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_34 = m.add_instruction(migraphx.op("add"), [x_30, x_33])
# Fractional parts (coord - floor), their complements (x_9 - frac), and the four
# pairwise products; these products later multiply the gathered corner values
# (presumably the linear-interpolation weights).
x_35 = m.add_instruction(migraphx.op("sub"), [x_22, x_29])
x_36 = m.add_instruction(migraphx.op("sub"), [x_28, x_30])
x_37 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_38 = m.add_instruction(migraphx.op("sub"), [x_37, x_35])
x_39 = m.add_instruction(migraphx.op("multibroadcast", out_lens=[1,1,1]), [x_9])
x_40 = m.add_instruction(migraphx.op("sub"), [x_39, x_36])
x_41 = m.add_instruction(migraphx.op("mul"), [x_40, x_38])
x_42 = m.add_instruction(migraphx.op("mul"), [x_40, x_35])
x_43 = m.add_instruction(migraphx.op("mul"), [x_36, x_38])
x_44 = m.add_instruction(migraphx.op("mul"), [x_36, x_35])
# Gather the corner coordinates, clip them to literal bounds, and use `equal`
# to record whether clipping left each coordinate unchanged (i.e. in range).
x_45 = m.add_instruction(migraphx.op("gathernd"), [x_30, x_1])
x_46 = m.add_instruction(migraphx.op("gathernd"), [x_29, x_1])
x_47 = m.add_instruction(migraphx.op("gathernd"), [x_34, x_1])
x_48 = m.add_instruction(migraphx.op("gathernd"), [x_32, x_1])
x_49 = m.add_instruction(migraphx.op("clip"), [x_45, x_10, x_5])
x_50 = m.add_instruction(migraphx.op("equal"), [x_45, x_49])
x_51 = m.add_instruction(migraphx.op("clip"), [x_46, x_10, x_7])
x_52 = m.add_instruction(migraphx.op("equal"), [x_46, x_51])
x_53 = m.add_instruction(migraphx.op("clip"), [x_47, x_10, x_5])
x_54 = m.add_instruction(migraphx.op("equal"), [x_47, x_53])
x_55 = m.add_instruction(migraphx.op("clip"), [x_48, x_10, x_7])
x_56 = m.add_instruction(migraphx.op("equal"), [x_48, x_55])
# Build four gathernd index rows: x_0 ([1,2]) followed by two clipped coordinates.
x_57 = m.add_instruction(migraphx.op("reshape", dims=[1,1]), [x_49])
x_58 = m.add_instruction(migraphx.op("reshape", dims=[1,1]), [x_51])
x_59 = m.add_instruction(migraphx.op("reshape", dims=[1,1]), [x_53])
x_60 = m.add_instruction(migraphx.op("reshape", dims=[1,1]), [x_55])
x_61 = m.add_instruction(migraphx.op("concat", axis=1), [x_57, x_58])
x_62 = m.add_instruction(migraphx.op("concat", axis=1), [x_0, x_61])
x_63 = m.add_instruction(migraphx.op("concat", axis=1), [x_57, x_60])
x_64 = m.add_instruction(migraphx.op("concat", axis=1), [x_0, x_63])
x_65 = m.add_instruction(migraphx.op("concat", axis=1), [x_59, x_58])
x_66 = m.add_instruction(migraphx.op("concat", axis=1), [x_0, x_65])
x_67 = m.add_instruction(migraphx.op("concat", axis=1), [x_59, x_60])
x_68 = m.add_instruction(migraphx.op("concat", axis=1), [x_0, x_67])
# Combine the per-axis in-range masks. The equal/logical_and/where/mul/add
# instructions in this program (4 equal, 4 logical_and, 4 where, 4 mul, 3 add)
# match the ops named in the fused GPU kernel discussed in this issue.
x_69 = m.add_instruction(migraphx.op("logical_and"), [x_52, x_50])
x_70 = m.add_instruction(migraphx.op("logical_and"), [x_56, x_50])
x_71 = m.add_instruction(migraphx.op("logical_and"), [x_52, x_54])
x_72 = m.add_instruction(migraphx.op("logical_and"), [x_56, x_54])
# Reshape of x_1 to [1,3]; x_1 already has lens [1,3].
x_73 = m.add_instruction(migraphx.op("reshape", dims=[1,3]), [x_1])
# Gather the image value at each corner; `where` substitutes literal x_10 when
# the corner's in-range mask is false.
x_74 = m.add_instruction(migraphx.op("gathernd"), [p_x, x_62])
x_75 = m.add_instruction(migraphx.op("where"), [x_69, x_74, x_10])
x_76 = m.add_instruction(migraphx.op("gathernd"), [p_x, x_64])
x_77 = m.add_instruction(migraphx.op("where"), [x_70, x_76, x_10])
x_78 = m.add_instruction(migraphx.op("gathernd"), [p_x, x_66])
x_79 = m.add_instruction(migraphx.op("where"), [x_71, x_78, x_10])
x_80 = m.add_instruction(migraphx.op("gathernd"), [p_x, x_68])
x_81 = m.add_instruction(migraphx.op("where"), [x_72, x_80, x_10])
# Multiply each corner value by its product term gathered from x_41..x_44.
x_82 = m.add_instruction(migraphx.op("gathernd"), [x_41, x_73])
x_83 = m.add_instruction(migraphx.op("mul"), [x_75, x_82])
x_84 = m.add_instruction(migraphx.op("gathernd"), [x_42, x_73])
x_85 = m.add_instruction(migraphx.op("mul"), [x_77, x_84])
x_86 = m.add_instruction(migraphx.op("gathernd"), [x_43, x_73])
x_87 = m.add_instruction(migraphx.op("mul"), [x_79, x_86])
x_88 = m.add_instruction(migraphx.op("gathernd"), [x_44, x_73])
x_89 = m.add_instruction(migraphx.op("mul"), [x_81, x_88])
# Sum the four weighted corner terms. x_92 is the last instruction in the module.
# NOTE(review): no m.add_return(...) appears in this dump.
x_90 = m.add_instruction(migraphx.op("add"), [x_83, x_85])
x_91 = m.add_instruction(migraphx.op("add"), [x_90, x_87])
x_92 = m.add_instruction(migraphx.op("add"), [x_91, x_89])
gyulaz-htec commented 5 months ago

@pfultz2 I think I've found the issue with the equal_equal_logical_and_where_mul_equal_logical_and_where_mul_equal_logical_and_where_mul_logical_and_where_mul_add_add_add_kernel kernel function. The order of params/attributes differs between the preamble and the corresponding kernel function.

preamble:

template<class Tx0, class Tx1, class Tx10, class Tx11, class Tx12, class Tx13, class Tx14, class Tx15, class Tx2, class Tx3, class Tx4, class Tx5, class Tx6, class Tx7, class Tx8, class Tx9>
__device__ __attribute__((const)) auto inner_pointwise(Tx0 x0,Tx1 x1,Tx10 x10,Tx11 x11,Tx12 x12,Tx13 x13,Tx14 x14,Tx15 x15,Tx2 x2,Tx3 x3,Tx4 x4,Tx5 x5,Tx6 x6,Tx7 x7,Tx8 x8,Tx9 x9) {

kernel function:

// NOTE(review): excerpt of the generated kernel entry point. It forwards
// private_p0..private_p16 to pointwise() in numeric order. The preamble's
// inner_pointwise, however, declares its parameters in string order
// (x0, x1, x10..x15, x2..x9), so positional binding puts private_p2 into x10.
MIGRAPHX_GLOBAL void equal_equal_logical_and_where_mul_equal_logical_and_where_mul_equal_logical_and_where_mul_logical_and_where_mul_add_add_add_kernel(void * private_p0,void * private_p1,void * private_p2,void * private_p3,void * private_p4,void * private_p5,void * private_p6,void * private_p7,void * private_p8,void * private_p9,void * private_p10,void * private_p11,void * private_p12,void * private_p13,void * private_p14,void * private_p15,void * private_p16) 
{
    auto idx = make_index();
    // inner_pointwise takes 16 inputs while 17 pointers are forwarded;
    // private_p16 is presumably the output buffer -- TODO confirm.
    pointwise<1>(idx, vectorize<1, 0>())(MIGRAPHX_LIFT(inner_pointwise), private_p0,private_p1,private_p2,private_p3,private_p4,private_p5,private_p6,private_p7,private_p8,private_p9,private_p10,private_p11,private_p12,private_p13,private_p14,private_p15,private_p16);
}

Notice that private_p2 is passed down in place of the 11th argument, x10; the same applies to private_p3 ... private_p7.

After I've fixed the order of arguments in the kernel call the issue is fixed. I will tidy up my solution and create a PR for it.

The full kernel function:

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>

namespace migraphx {

// Generated fused pointwise body for the GridSample (linear) subgraph: four
// equal/logical_and/where/mul chains whose results are summed.
// NOTE(review): BUG -- the template and function parameter lists below are
// ordered as strings (x0, x1, x10..x15, x2..x9), not numerically. The kernel
// entry point at the bottom of this namespace forwards private_p0..private_p15
// positionally, so private_p2 binds to x10, private_p3 to x11, and so on. Per
// the thread, this appears once a fused kernel has more than 10 parameters
// (names x10 and above). Parameter names must sort numerically, or be emitted
// in numeric order, for the binding to be correct.
template<class Tx0, class Tx1, class Tx10, class Tx11, class Tx12, class Tx13, class Tx14, class Tx15, class Tx2, class Tx3, class Tx4, class Tx5, class Tx6, class Tx7, class Tx8, class Tx9>
__device__ __attribute__((const)) auto inner_pointwise(Tx0 x0,Tx1 x1,Tx10 x10,Tx11 x11,Tx12 x12,Tx13 x13,Tx14 x14,Tx15 x15,Tx2 x2,Tx3 x3,Tx4 x4,Tx5 x5,Tx6 x6,Tx7 x7,Tx8 x8,Tx9 x9) {
// @param:x5 -> float_type, {1}, {0}
// @literal -> float_type, {1}, {0}
// zz1 is the 0 used as the "false" operand of every where() below.
auto zz1 = float(0);
// @param:x4 -> float_type, {1}, {0}
// @param:x1 -> float_type, {1}, {0}
// @param:x0 -> float_type, {1}, {0}
// equal -> float_type, {1}, {0}
// abs() around == : per this thread, vector comparisons can yield -1 for true,
// so abs normalises the result to 0/1.
auto zz5 = migraphx::convert<float>(migraphx::abs(x0 == x1));
// @param:x3 -> float_type, {1}, {0}
// @param:x2 -> float_type, {1}, {0}
// equal -> float_type, {1}, {0}
auto zz8 = migraphx::convert<float>(migraphx::abs(x2 == x3));
// logical_and -> float_type, {1}, {0}
// NOTE(review): && is applied with no abs() wrapper; the thread suggests
// logical_and may need a dedicated implementation -- unverified.
auto zz9 = migraphx::convert<float>(zz8 && zz5);
// where -> float_type, {1}, {0}
auto zz10 = migraphx::convert<float>(migraphx::where(zz9, x4, zz1));
// mul -> float_type, {1}, {0}
auto zz11 = migraphx::convert<float>(zz10 * x5);
// @param:x9 -> float_type, {1}, {0}
// @param:x8 -> float_type, {1}, {0}
// @param:x7 -> float_type, {1}, {0}
// @param:x6 -> float_type, {1}, {0}
// equal -> float_type, {1}, {0}
auto zz16 = migraphx::convert<float>(migraphx::abs(x6 == x7));
// logical_and -> float_type, {1}, {0}
auto zz17 = migraphx::convert<float>(zz16 && zz5);
// where -> float_type, {1}, {0}
auto zz18 = migraphx::convert<float>(migraphx::where(zz17, x8, zz1));
// mul -> float_type, {1}, {0}
auto zz19 = migraphx::convert<float>(zz18 * x9);
// @param:x13 -> float_type, {1}, {0}
// @param:x12 -> float_type, {1}, {0}
// @param:x11 -> float_type, {1}, {0}
// @param:x10 -> float_type, {1}, {0}
// equal -> float_type, {1}, {0}
auto zz24 = migraphx::convert<float>(migraphx::abs(x10 == x11));
// logical_and -> float_type, {1}, {0}
auto zz25 = migraphx::convert<float>(zz8 && zz24);
// where -> float_type, {1}, {0}
auto zz26 = migraphx::convert<float>(migraphx::where(zz25, x12, zz1));
// mul -> float_type, {1}, {0}
auto zz27 = migraphx::convert<float>(zz26 * x13);
// @param:x15 -> float_type, {1}, {0}
// @param:x14 -> float_type, {1}, {0}
// logical_and -> float_type, {1}, {0}
// Reuses the masks zz16 and zz24 computed above (no new equal here).
auto zz30 = migraphx::convert<float>(zz16 && zz24);
// where -> float_type, {1}, {0}
auto zz31 = migraphx::convert<float>(migraphx::where(zz30, x14, zz1));
// mul -> float_type, {1}, {0}
auto zz32 = migraphx::convert<float>(zz31 * x15);
// Sum of the four where*mul terms (zz11, zz19, zz27, zz32).
// add -> float_type, {1}, {0}
auto zz33 = migraphx::convert<float>(zz32 + zz27);
// add -> float_type, {1}, {0}
auto zz34 = migraphx::convert<float>(zz33 + zz19);
// add -> float_type, {1}, {0}
auto zz35 = migraphx::convert<float>(zz34 + zz11);
// @return -> float_type, {1}, {0}
auto zzreturn = make_tuple(zz35);
return zzreturn;

}

extern "C" {
// Kernel entry point: forwards private_p0..private_p16 in numeric order to
// pointwise(), which binds them to inner_pointwise's parameters by position
// (see the NOTE on inner_pointwise's parameter order). inner_pointwise takes 16
// inputs, so private_p16 is presumably the output buffer -- TODO confirm.
MIGRAPHX_GLOBAL void equal_equal_logical_and_where_mul_equal_logical_and_where_mul_equal_logical_and_where_mul_logical_and_where_mul_add_add_add_kernel(void * private_p0,void * private_p1,void * private_p2,void * private_p3,void * private_p4,void * private_p5,void * private_p6,void * private_p7,void * private_p8,void * private_p9,void * private_p10,void * private_p11,void * private_p12,void * private_p13,void * private_p14,void * private_p15,void * private_p16) 
{
    auto idx = make_index();
    pointwise<1>(idx, vectorize<1, 0>())(MIGRAPHX_LIFT(inner_pointwise), private_p0,private_p1,private_p2,private_p3,private_p4,private_p5,private_p6,private_p7,private_p8,private_p9,private_p10,private_p11,private_p12,private_p13,private_p14,private_p15,private_p16);
}

}

} // namespace migraphx
pfultz2 commented 5 months ago

Ah, I know the issue: it uses more than 10 parameters, which messes up the sorting (the names are sorted as strings, so x10 sorts before x2).

pfultz2 commented 5 months ago

I added the param_name function in order to generate names that sort correctly. Let me update all of our fusions to use it (and update it to handle indices of 10 and above).