ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

Resnext50 failing to run on MIGraphX Driver #1283

Closed: causten closed this issue 1 year ago

causten commented 2 years ago

While trying https://zenodo.org/record/6617879/files/resnext50_32x4d_fpn.onnx with:

migraphx-driver perf ./resnext50_32x4d_fpn.onnx

I hit an error indicating that an operator is not supported by MIGraphX:

what(): /workspace/AMDMIGraphX/src/onnx/onnx_parser.cpp:318: parse_graph: Unknown operator: Mod

Investigate and add the missing operator.

TedThemistokleous commented 2 years ago

A quick check with the driver's list of our operators confirms that Mod isn't implemented.

I started on this today but hit a bit of a snag: the mod operator (%) doesn't work with non-integer types, so I resorted to using fmod for float, double and half. It looks like we can implicitly use half_float::half for this, and I've set up my templates accordingly to pick the right modulus call.
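A minimal sketch of the kind of type dispatch described above, assuming C++17 if constexpr. The helper name typed_mod is hypothetical (not taken from the MIGraphX source), and half_float::half is left out because it is a third-party type; the sketch only covers the built-in integral and floating-point types.

```cpp
#include <cmath>
#include <type_traits>

// Hypothetical helper: picks the modulus call from the element type.
// Integral types use the built-in % operator; floating-point types use
// std::fmod, because % is ill-formed for them. Both truncate toward zero,
// so the result takes the sign of the dividend x. y must be nonzero.
template <class T>
T typed_mod(T x, T y)
{
    static_assert(std::is_arithmetic<T>::value, "arithmetic types only");
    if constexpr(std::is_integral<T>::value)
        return x % y;
    else
        return std::fmod(x, y);
}
```

For example, typed_mod(7, -3) gives 1 and typed_mod(-7.5, 2.0) gives -1.5. Note that this is C-style truncated behavior only; as the thread works out below, it does not cover ONNX Mod with fmod=0.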

I got some help from Umang on how to set up the compute_shape() call, and I have some tests running so far for reference.

Looking at the ONNX spec, there's an oddity: the function behaves differently depending on the fmod attribute.

https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mod

I'm still determining how to add this based on the input data type, but so far I have the base functionality working in tests.

The goal is to hammer on this and get the resnext50 model working correctly, hopefully by tomorrow or Monday.

pfultz2 commented 2 years ago

For the GPU side, we will need to add mod functions to math.hpp. If the class/operator name is mod, then it will use the function migraphx::mod on the GPU by default when using the op::binary<> class.

For the CPU/ref side, we will probably need to add overloaded mod functions as well, probably for both mod and fmod.

> Taking a look at the Onnx spec there's some oddity here with the function being different based on the fmod attribute.

We could add two different operators, fmod and imod, but we would want them to work generically for all types. Depending on the fmod flag, we would insert the appropriate operator.
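A hypothetical sketch of that parser-side choice, assuming C++17: the ONNX fmod attribute defaults to 0 when absent, so the parser would only insert the fmod-style operator when the attribute is explicitly 1. The function name select_mod_op, the operator names it returns, and the use of std::optional to model a missing attribute are all illustrative assumptions, not MIGraphX's ONNX parser API.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical: choose which operator to insert for an ONNX Mod node.
// fmod_attr is empty when the node omits the attribute.
std::string select_mod_op(std::optional<std::int64_t> fmod_attr)
{
    // The ONNX spec makes fmod=0 the default.
    const std::int64_t fmod = fmod_attr.value_or(0);
    if(fmod == 0)
        return "mod"; // numpy.mod semantics: result has the sign of the divisor
    if(fmod == 1)
        return "fmod"; // numpy.fmod semantics: result has the sign of the dividend
    throw std::invalid_argument("Mod: fmod attribute must be 0 or 1");
}
```

Keeping the decision in the parser means each operator can stay generic over element types, which is the goal stated above.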

TedThemistokleous commented 2 years ago

There are a few ways we could go about it, and I'm very much open to suggestions on what you'd prefer under the hood.

I just pushed my latest change set, but it's failing the test_verify case, since I'm still trying to figure out how point_function plays a role in this. I was using the convert and mul operators that Umang pointed me to as a guide to get this going.

I'm trying to capture the default-attribute behavior of the ONNX fmod attribute, and have users explicitly set it to true when they want float functionality. I'm not sure if I should be testing the half case as well for ref.

I feel like there need to be a few more discussions around this one because of the difference that one flag makes.

Thanks for the gotcha on the GPU case. I didn't even realize that.

pfultz2 commented 2 years ago

> I just pushed my latest change set I'm working on

I left some comments; hope that's okay.

> I'm trying to capture the default attribute nature of the fmod onnx operator, and have users explicitly set things as true when they want float functionality.

The difference with the fmod flag is the difference between numpy.mod and numpy.fmod, which is unrelated to whether the input is float or integer:

https://www.skytowner.com/explore/difference_between_the_methods_mod_and_fmod_in_numpy

You can also see that in the test cases from the ONNX spec: with fmod=1, integer inputs use numpy.fmod:

import numpy as np
import onnx
from onnx.backend.test.case.node import expect

node = onnx.helper.make_node(
    'Mod',
    inputs=['x', 'y'],
    outputs=['z'],
    fmod=1
)

x = np.array([-4, 7, 5, 4, -7, 8]).astype(np.int64)
y = np.array([2, -3, 8, -2, 3, 5]).astype(np.int64)
z = np.fmod(x, y)  # expected output [ 0,  1,  5,  0, -1,  3]
expect(node, inputs=[x, y], outputs=[z],
       name='test_mod_int64_fmod')

And with fmod=0, the default, it uses numpy.mod:

node = onnx.helper.make_node(
    'Mod',
    inputs=['x', 'y'],
    outputs=['z'],
)

x = np.array([-4, 7, 5, 4, -7, 8]).astype(np.int64)
y = np.array([2, -3, 8, -2, 3, 5]).astype(np.int64)
z = np.mod(x, y)  # expected output [ 0, -2,  5,  0,  2,  3]
expect(node, inputs=[x, y], outputs=[z],
       name='test_mod_mixed_sign_int64')

This is why I think it makes sense to have two operators, mod and fmod. fmod can call std::fmod, which also works for integers, and mod could call std::remainder or std::abs(std::fmod(...)) (I am not sure which has the same behavior as numpy.mod).
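To make the two behaviors concrete, here is a minimal sketch for int64 inputs; fmod_like and mod_like are hypothetical names, not MIGraphX functions. fmod_like truncates toward zero like std::fmod and C++'s %, so the result takes the dividend's sign. mod_like produces numpy.mod's floored result by shifting a nonzero remainder by the divisor whenever their signs disagree. Neither std::remainder, which rounds the quotient to the nearest integer, nor std::abs(std::fmod(...)) reproduces numpy.mod in general: for x = 7, y = -3, numpy.mod gives -2, while both of those give 1.

```cpp
#include <cstdint>

// Hypothetical: truncated modulus (ONNX Mod with fmod=1, numpy.fmod).
// The result has the sign of the dividend x. y must be nonzero.
std::int64_t fmod_like(std::int64_t x, std::int64_t y) { return x % y; }

// Hypothetical: floored modulus (ONNX Mod with fmod=0, numpy.mod).
// The result has the sign of the divisor y. y must be nonzero.
std::int64_t mod_like(std::int64_t x, std::int64_t y)
{
    std::int64_t r = x % y;
    // A nonzero remainder whose sign disagrees with the divisor is
    // shifted by one divisor to land in the floored range.
    if(r != 0 && ((r < 0) != (y < 0)))
        r += y;
    return r;
}
```

Applied elementwise to the ONNX test inputs quoted above, fmod_like gives [0, 1, 5, 0, -1, 3] and mod_like gives [0, -2, 5, 0, 2, 3], matching the np.fmod and np.mod expectations. A floating-point version would follow the same pattern with std::fmod in place of %.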

> since I'm still trying to figure out how these point_function plays a role in this.

When we generate this code for runtime compilation, it uses the point_function name to call the function of the same name, prefixed with migraphx::.
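A toy illustration of that naming convention, under the assumption that the generated source is plain C++ text: the op's point_function name is prefixed with migraphx:: and applied to the argument names. make_pointwise_call and the argument names x0, x1 are hypothetical; this is not MIGraphX's actual code generator.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical: build the source text for one pointwise call, e.g.
// point_function "mod" with args {"x0", "x1"} -> "migraphx::mod(x0, x1)".
std::string make_pointwise_call(const std::string& point_function,
                                const std::vector<std::string>& args)
{
    std::string call = "migraphx::" + point_function + "(";
    for(std::size_t i = 0; i < args.size(); ++i)
    {
        if(i > 0)
            call += ", ";
        call += args[i];
    }
    return call + ")";
}
```

Because the generated call is resolved by name at compile time, the generated code for a mod op can only compile if a migraphx::mod overload exists for the element type involved.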

In the math.hpp header, we define these functions for all the data types, including vector types. We currently don't have overloads for fmod and remainder. I think HIP provides fmod, but I don't know if it works for integers, which means we would need to provide an overload for that. Also, I don't know if HIP provides remainder, so we might need to put our own implementation there.

TedThemistokleous commented 2 years ago

No issue at all with your comments, and I really appreciate the link; it clears up a bunch of ambiguity I had. I took a look and found a few other things regarding getting C++ to play nice and behave like Python:

https://stackoverflow.com/questions/1907565/c-and-python-different-behaviour-of-the-modulo-operation

I had initially tried remainder but was getting different results.

Following your advice, I was able to split things into equivalent fmod and mod functions, while removing all the templating I had done previously. I had assumed we wanted it to be similar to C's mod functionality.

I've added some additional test cases for both integral and float inputs. Let me know if I should also try half_type as a separate test. I'll push this as a PR to develop so we get both operators.

Do we need to do something similar on the TensorFlow side once we add the ONNX operator?

TedThemistokleous commented 2 years ago

@causten On another note, rerunning resnext after adding the mod operator, I'm now getting odd output, with parse_if failing:

main:@2642 = convert[target_type=0](main:@2641) -> bool_type, {1}, {0}
main:@2643 = convert[target_type=0](main:@2642) -> bool_type, {1}, {0}

module: "If_28_if"
If_28_if:@0 = @literal{1} -> int64_type, {1}, {1}
If_28_if:@1 = squeeze[axes={1}](?) -> int64_type, {31702968}, {1}
If_28_if:@2 = @return(If_28_if:@1)

module: "If_2644_else"
If_2644_else:@0 = @literal{1} -> int64_type, {1}, {1}
If_2644_else:@1 = @literal{31702968, 1} -> int64_type, {2}, {1}
If_2644_else:@2 = @literal{1} -> int64_type, {1}, {1}
If_2644_else:@3 = @literal{2} -> int64_type, {1}, {1}
If_2644_else:@4 = @literal{0.5} -> float_type, {1}, {1}
If_2644_else:@5 = @literal{9223372036854775807} -> int64_type, {1}, {1}
If_2644_else:@6 = @literal{0} -> int64_type, {1}, {1}
If_2644_else:@7 = @literal{0} -> int64_type, {1}, {1}
If_2644_else:@8 = @literal{0} -> int64_type, {1}, {1}
If_2644_else:@9 = @literal{1} -> int64_type, {1}, {1}
If_2644_else:@10 = @literal{1} -> float_type, {1}, {0}
If_2644_else:@11 = reduce_max[axes={0, 1}](main:@2636) -> float_type, {1, 1}, {1, 1}
If_2644_else:@12 = squeeze[axes={0, 1}](If_2644_else:@11) -> float_type, {1}, {0}
If_2644_else:@13 = convert[target_type=2](main:@2638) -> float_type, {31702968}, {1}
If_2644_else:@14 = add(If_2644_else:@12,If_2644_else:@10) -> float_type, {1}, {0}
If_2644_else:@15 = multibroadcast[out_lens={31702968}](If_2644_else:@14) -> float_type, {31702968}, {0}
If_2644_else:@16 = mul(If_2644_else:@13,If_2644_else:@15) -> float_type, {31702968}, {1}
If_2644_else:@17 = unsqueeze[axes={1}](If_2644_else:@16) -> float_type, {31702968, 1}, {1, 1}
If_2644_else:@18 = multibroadcast[out_lens={31702968, 4}](If_2644_else:@17) -> float_type, {31702968, 4}, {1, 0}
If_2644_else:@19 = add(main:@2636,If_2644_else:@18) -> float_type, {31702968, 4}, {4, 1}
If_2644_else:@20 = unsqueeze[axes={0}](If_2644_else:@19) -> float_type, {1, 31702968, 4}, {126811872, 4, 1}
If_2644_else:@21 = unsqueeze[axes={0}](main:@2637) -> float_type, {1, 31702968}, {31702968, 1}
If_2644_else:@22 = unsqueeze[axes={0}](If_2644_else:@21) -> float_type, {1, 1, 31702968}, {31702968, 31702968, 1}
If_2644_else:@23 = nonmaxsuppression[center_point_box=0](If_2644_else:@20,If_2644_else:@22,If_2644_else:@5,If_2644_else:@4) -> int64_type, {31702968, 3}, {3, 1}
If_2644_else:@24 = gather[axis=1](If_2644_else:@23,If_2644_else:@3) -> int64_type, {31702968, 1}, {1, 1}
If_2644_else:@25 = gather[axis=0](If_2644_else:@1,If_2644_else:@2) -> int64_type, {1}, {1}
If_2644_else:@26 = equal(If_2644_else:@25,If_2644_else:@0) -> int64_type, {1}, {1}
If_2644_else:@27 = convert[target_type=0](If_2644_else:@26) -> bool_type, {1}, {1}

module: "If_2644_if"
If_2644_if:@0 = @literal{} -> float_type, {}, {}
If_2644_if:@1 = @return(If_2644_if:@0)

module: "If_28_else"
If_28_else:@0 = identity(If_2644_else:@24) -> int64_type, {31702968, 1}, {1, 1}
If_28_else:@1 = @return(If_28_else:@0)

terminate called after throwing an instance of 'migraphx::version_1::exception'
  what():  /code/AMDMIGraphX/src/onnx/parse_if.cpp:72: parse: PARSE_IF: then and else sub_grahps must have same output shapes!
Aborted (core dumped) 

I'm not sure what the course of action for debugging this is, but I can look into this module.

TedThemistokleous commented 2 years ago

More progress on this: I started looking at the resnext50 model in Netron, and it appears we have some odd behavior near the end of the network, just before we get the scores.

There's a chain of if/else nodes where one branch seems to just pass a value through, whereas the other (else) branch modifies the dimensions through a second chain but uses identity() on its output. I don't think we're handling the passed-through constant correctly: it's assumed to be an empty constant, but on our end that gets translated to a float-typed literal with zero dimensions.

TedThemistokleous commented 2 years ago

So it appears our parsing of this empty constant always defaults to float type, which is why we get a type mismatch when we return the parsed constant as a literal.

TedThemistokleous commented 2 years ago

The Mod operator ref implementation is complete. It still requires the GPU-side changes from #1306.

Network parsing still seems to be broken due to issues with parse_if and the related output shapes.

This is currently being resolved in PR #1325.

TedThemistokleous commented 2 years ago

@pfultz2 Thanks for your help the other day with the GPU-side changes. I'm assuming the JIT will handle the CPU target changes as well?

If not, do we have a doc on what the JIT currently supports and doesn't support? Where's the best place to look for that?

TedThemistokleous commented 1 year ago

An update on this: I've got something partially working, but I'm hitting a wall because the network uses a newer opset. Hopefully I'm at the tail end of handling the parse_if issues for static shapes in this network.

TedThemistokleous commented 1 year ago

The parse_if changes are ready. I just need to sort out some additional model conversion details.

It looks like I'm close. I got a bit more time to play with this and had to fork onnx so I could add support for this conversion. I see the relevant values for split being converted from opset 13 to 12 with the resnext50 model.

We're currently breaking on some other modules too, so I dove into the onnx code to solve this.

I opened a question and a PR on that project to get answers. I tried adding support locally for the missing conversion adapters, but I still keep getting odd errors.

PR: https://github.com/onnx/onnx/pull/4615
Issue: https://github.com/onnx/onnx/issues/4616

pfultz2 commented 1 year ago

Why can't we use the latest opset version for resnext50 in migraphx?

TedThemistokleous commented 1 year ago

Here's the difference between the latest version of split and our current implementation of split:

https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-11

https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-13

The resnext50 model passes split as an input instead of an attribute, so we fail with:

terminate called after throwing an instance of 'migraphx::version_1::exception'
  what():  /code/AMDMIGraphX/src/onnx/parse_split.cpp:71: parse: PARSE_SPLIT: input cannot be equally divided into 4 splits!
Aborted (core dumped)

pfultz2 commented 1 year ago

> The resnext50 model uses split as an input instead of an attribute so we fail and get

We should fix those issues with split instead of trying to convert to an older opset. It seems much easier to fix as well.

TedThemistokleous commented 1 year ago

I may have a workaround for this, which actually involves adding missing operators (SplitToSequence and SequenceAt) alongside modifying the network via https://github.com/ZhangGe6/onnx-modifier

It looks like I can also insert a new node and split this up that way.

The problem is that split doesn't work because it got updated in opset 13. So next opset update do we want to update things? It doesn't seem like we have a good provision to handling networks with new opsets and attributes/things switched around outside of rebuilding an older MIGraphX lib.

pfultz2 commented 1 year ago

> So next opset update do we want to update things?

What do you mean opset update?

> It doesn't seem like we have a good provision to handling networks with new opsets

We usually update the opset version for an operator when we have new models that use that version of the operator. We don't have a good test suite for this, so it usually just covers what's in the models we are testing.

> attributes/things switched around outside of rebuilding an older MIGraphX lib.

We should support all the opsets; you shouldn't need to build an older version of MIGraphX. Take a look at parse_reshape: it uses the attribute when available, and otherwise uses the input passed as a parameter. So a model could use either opset.
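A hypothetical sketch of that approach applied to Split, assuming C++17: prefer the split attribute (Split-11), fall back to the split input (Split-13), and only divide the axis evenly when neither is given. The function name resolve_split_sizes and the use of std::optional for "not provided" are assumptions for illustration; this is not MIGraphX's parse_split.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <vector>

using sizes_t = std::vector<std::int64_t>;

// Hypothetical: resolve Split's chunk sizes for either opset.
//  split_attr:  the split attribute (Split-11 and earlier)
//  split_input: the constant split input (Split-13)
sizes_t resolve_split_sizes(std::int64_t axis_len,
                            std::size_t num_outputs,
                            const std::optional<sizes_t>& split_attr,
                            const std::optional<sizes_t>& split_input)
{
    // Prefer the attribute when present, otherwise use the input.
    const std::optional<sizes_t>& given = split_attr ? split_attr : split_input;
    if(given)
    {
        const std::int64_t total =
            std::accumulate(given->begin(), given->end(), std::int64_t{0});
        if(given->size() != num_outputs || total != axis_len)
            throw std::invalid_argument("split sizes do not match the input");
        return *given;
    }
    // Neither provided: the axis must divide evenly across the outputs.
    const auto n = static_cast<std::int64_t>(num_outputs);
    if(n == 0 || axis_len % n != 0)
        throw std::invalid_argument("input cannot be equally divided");
    return sizes_t(num_outputs, axis_len / n);
}
```

Under this scheme, a model that passes the sizes as an input would not fall through to the equal-division check, which is presumably where the PARSE_SPLIT error above comes from.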

TedThemistokleous commented 1 year ago

Yeah, I resorted to that and have a fix I'll push up shortly. From talking to the onnx team, it looks like they really only support moving older models forward; moving newer models back to older opsets is only loosely supported, since some older ops don't support specific data types.

So the model converter should work if we want to bring everything up to a certain version, say to try newer data types (bfloat comes to mind), but going backwards seems like a fruitless endeavor now, as the support for that is spotty.

TedThemistokleous commented 1 year ago

Maybe a later solution is to run the conversion script on older models anyway, moving them forward to see if we have any existing deficiencies, while retaining the original models that already work. Based on those tests, that could give us some assurance that anyone updating to a newer model wouldn't break compatibility. The existing models we test are good real-world examples of how things are used.

TedThemistokleous commented 1 year ago

I've been able to show shape inference on this model using the changes from add_onnx_opset_converter in #1413.

I was able to just "Turn on shape inference" and found the following:

[image: shape inference results for the model]

The unk_## dimensions indicate that the input to this last set of operators requires a dynamic shape to infer the results correctly.

Admittedly, the split operator did need a fix to handle both split versions (Split-11 and Split-13), but I was seeing errors with output shapes on this.

TedThemistokleous commented 1 year ago

I sorted out the shape issue. I realized I was unsqueezing an extra dimension onto the smaller output rather than squeezing down the larger one in the If. Squeezing seemed to work with some other networks I tried (mainly those involving SplitToSequence), and I was able to read in the network.
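A hypothetical sketch of that reconciliation, assuming C++17: when the then and else outputs differ only by size-1 dimensions, find the size-1 axes of the higher-rank shape that can be squeezed away, rather than unsqueezing the lower-rank one. squeeze_axes_to_match is an illustrative name operating on plain dimension vectors, not MIGraphX's parse_if.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using dims_t = std::vector<std::int64_t>;

// Hypothetical: return the size-1 axes of `big` that, once squeezed,
// leave exactly `small`; std::nullopt if no such axes exist.
std::optional<std::vector<std::int64_t>> squeeze_axes_to_match(const dims_t& big,
                                                               const dims_t& small)
{
    if(big.size() < small.size())
        return std::nullopt;
    std::vector<std::int64_t> axes;
    std::size_t j = 0; // next dimension of `small` still to match
    for(std::size_t i = 0; i < big.size(); ++i)
    {
        if(j < small.size() && big[i] == small[j])
            ++j; // keep this dimension
        else if(big[i] == 1)
            axes.push_back(static_cast<std::int64_t>(i)); // squeezable
        else
            return std::nullopt; // a non-1 dimension cannot be dropped
    }
    if(j != small.size())
        return std::nullopt;
    return axes;
}
```

For the shapes in the IR dump above, {31702968, 1} from If_28_else against {31702968} from If_28_if, this returns axis 1, i.e. a squeeze[axes={1}] on the else output.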

I'm currently able to get our driver to run perf with the ref target (attached: debug_resnext50_flags.txt). Here's an excerpt from the attached:

Summary:
ref::convolution: 706985ms / 111 = 6369.23ms, 73%
ref::op: 227199ms / 1835 = 123.814ms, 24%
if: 38688.3ms / 2 = 19344.1ms, 4%
@literal: 47.7884ms / 394 = 0.12129ms, 1%
@param: 0.00366ms / 1 = 0.00366ms, 1%

Batch size: 1
Rate: 0.00106105/sec
Total time: 942464ms
Total instructions time: 972920ms
Overhead time: 0.835036ms, -30455.5ms
Overhead: 0%, -3%

I'm currently working on getting the GPU implementation running. I'm seeing a failure when I use pointwise fusions, so I've disabled them, but now I keep getting a segfault. I took a backtrace with lldb-12 and will go through it in the morning: debug_resnext50_gpu.txt

I may need another pair of eyes on this.

On the matter of parse_if, I still need to sort out some other changes based on PR comments to get resnext50 fully working, but the current state should be enough to get resnext50 going for debugging.

zack-ch commented 1 year ago

Following up on the current state: I found that this segmentation fault comes from the simplify_reshapes pass. Debug build, run with:

root@mun-node-4:~/AMDMIGraphX/build# MIGRAPHX_TRACE_PASSES=1 MIGRAPHX_DISABLE_POINTWISE_FUSION=1 migraphx-driver perf --gpu /datasets/resnext50_32x4d_fpn.onnx 2>&1 | tee log.txt

// Log skipped
Module: If_28_else, Pass: simplify_reshapes
If_28_else:@0 = convert[target_type=2](?) -> float_type, {31702968, 1}, {1, 1}
If_28_else:@1 = identity(If_28_else:@0) -> float_type, {31702968, 1}, {1, 1}
If_28_else:@2 = convert[target_type=9](If_28_else:@1) -> int64_type, {31702968, 1}, {1, 1}
If_28_else:@3 = identity(?) -> int64_type, {31702968, 1}, {1, 1}
If_28_else:@4 = convert[target_type=2](If_28_else:@2) -> float_type, {31702968, 1}, {1, 1}
If_28_else:@5 = squeeze[axes={1}](If_28_else:@4) -> float_type, {31702968}, {1}
If_28_else:@6 = convert[target_type=9](If_28_else:@5) -> int64_type, {31702968}, {1}
If_28_else:@7 = squeeze[axes={1}](If_28_else:@2) -> int64_type, {31702968}, {1}
If_28_else:@8 = @return(If_28_else:@6)

convert[target_type=9](gather[axis=1]) -> int64_type, {31702968, 1}, {1, 1}
gather[axis=1](convert[target_type=2], convert[target_type=2]) -> float_type, {31702968, 1}, {1, 1}
convert[target_type=2](convert[target_type=9]) -> float_type, {31702968, 1}, {1, 1}
migraphx-driver: /root/AMDMIGraphX/src/module.cpp:323: migraphx::instruction_ref migraphx::module::replace_instruction(migraphx::instruction_ref, migraphx::instruction_ref): Assertion `has_instruction(rep)' failed.

I modified the code to dump the module and the matched instructions:

root@mun-node-4:~/AMDMIGraphX/build# git diff ../src/
diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp
index 4105fc2d4..6d0117705 100644
--- a/src/simplify_reshapes.cpp
+++ b/src/simplify_reshapes.cpp
@@ -194,7 +194,10 @@ struct find_nested_convert
         auto ins   = mr.result;
         auto x     = ins->inputs().front();
         auto input = x->inputs().front();
-
+       std::cout << m << std::endl;
+       x->debug_print();
+       input->debug_print();
+       ins->debug_print();
         if(ins->get_shape() != input->get_shape())
             return;

TedThemistokleous commented 1 year ago

Yes, I'm at the same state with this right now, Zac. Thanks for confirming. You need to turn off pointwise fusions to get past this.

If you run this with export MIGRAPHX_DISABLE_POINTWISE_FUSION=1, you'll turn off fusions and avoid this.

You can also trace compilation with export MIGRAPHX_TRACE_COMPILE=1, so you don't need to modify the code while debugging to get output. You can trace evaluation similarly with:

export MIGRAPHX_TRACE_EVAL=2

Here's what I get when I turn on eval tracing and pointwise fusions:

Allocating params ... 
Running performance report ... 
Run instruction: main:@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {}
Time: 0.00755ms, 0.00818ms
Run instruction: main:@1 = hip::hip_allocate_memory[shape=float_type, {484627968}, {1},id=main:scratch] -> float_type, {484627968}, {1}
Time: 0.00938ms, 0.00982ms
Output has subnormal, zero, nan, normal
Output: -0.372549, -0.372549, -0.372549, -0.372549, -0.372549, ..., -0.372549, -0.372549, -0.372549, -0.372549, -0.372549
Run instruction: main:@2 = hip::hip_copy_literal[id=main:@literal:67] -> float_type, {1}, {0}
Time: 0.01171ms, 0.01224ms
Output has zero
Output: 0
Run instruction: main:@3 = hip::hip_copy_literal[id=main:@literal:77] -> float_type, {3, 1, 1}, {1, 1, 1}
Time: 0.00226ms, 0.0028ms
Output has normal
Output: 4.36681, 4.46429, 4.44444
Run instruction: main:@4 = hip::hip_copy_literal[id=main:@literal:65] -> float_type, {3, 1, 1}, {1, 1, 1}
Time: 0.001681ms, 0.001961ms
Output has normal
Output: -2.1179, -2.03571, -1.80444
Run instruction: main:@5 = hip::hip_copy_literal[id=main:@literal:57] -> float_type, {64, 3, 7, 7}, {147, 49, 7, 1}
Time: 0.00137ms, 0.00165ms
Output has normal
Output: 0.00142852, -0.00465941, -0.00834894, 0.00791459, -0.000184702, ..., 0.00842322, 0.00779237, -0.00279974, 0.00753029, 0.0132863
Run instruction: main:@6 = multibroadcast[out_lens={3, 800, 800}](main:@4) -> float_type, {3, 800, 800}, {1, 0, 0}
Time: 0.008001ms, 0.008421ms
Output has normal
Output: -2.1179, -2.1179, -2.1179, -2.1179, -2.1179, ..., -1.80444, -1.80444, -1.80444, -1.80444, -1.80444
Run instruction: main:@7 = multibroadcast[out_lens={3, 800, 800}](main:@3) -> float_type, {3, 800, 800}, {1, 0, 0}
Time: 0.00441ms, 0.00491ms
Output has normal
Output: 4.36681, 4.36681, 4.36681, 4.36681, 4.36681, ..., 4.44444, 4.44444, 4.44444, 4.44444, 4.44444
Run instruction: images = @param:images -> float_type, {1, 3, 800, 800}, {1920000, 640000, 800, 1}
Time: 0.00601ms, 0.0068ms
Run instruction: main:@9 = squeeze[axes={0}](images) -> float_type, {3, 800, 800}, {640000, 800, 1}
Time: 0.00865ms, 0.00917ms
Output has zero, normal
Output: -0.6875, 0.75, 0.5625, -0.625, -0.25, ..., 0.75, 0.5625, 0.625, -0.25, -0.4375
Run instruction: main:@10 = mul(main:@7,main:@9) -> float_type, {3, 800, 800}, {640000, 800, 1}

Full log is here: resnext50_gpu_debug.txt

TedThemistokleous commented 1 year ago

Doing more debugging with lldb-12, running the following command after a few attempts at the input params:

bin/driver r perf ../resnext50_32x4d_fpn.onnx --fill0 images --input-dim @images 1 3 800 800 --output-names @boxes @labels @scores --disable-fast-math -n 1

With lldb-12:

bin/driver read ../resnext50_32x4d_fpn.onnx   &>    resnext50_gpu_debug_with_lldb-12.txt
r perf ../resnext50_32x4d_fpn.onnx --fill0 images --input-dim @images 1 3 800 800 --output-names @boxes @labels @scores --disable-fast-math -n 1 

resnext50_gpu_lldb-12_debug.txt

Taking a backtrace once this faults:

* thread #1, name = 'driver', stop reason = signal SIGSEGV: address access protected (fault address: 0x7ffeaea0b000)
  * frame #0: 0x00007fffcfb577f0 libmigraphx.so.2`auto void migraphx::version_1::detail::visit_all_flatten<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&, migraphx::version_1::argument&, migraphx::version_1::argument&, migraphx::version_1::argument&>(migraphx::version_1::shape const&, auto&&, auto&&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&...)::'lambda'(auto)::operator()<migraphx::version_1::shape::as<float> >(auto) const + 272
    frame #1: 0x00007fffcfb56ed0 libmigraphx.so.2`void migraphx::version_1::shape::visit<void migraphx::version_1::detail::visit_all_flatten<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&, migraphx::version_1::argument&, migraphx::version_1::argument&, migraphx::version_1::argument&>(migraphx::version_1::shape const&, auto&&, auto&&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&...)::'lambda'(auto), void migraphx::version_1::detail::visit_all_flatten<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&, migraphx::version_1::argument&, migraphx::version_1::argument&, migraphx::version_1::argument&>(migraphx::version_1::shape 
const&, auto&&, auto&&, auto migraphx::version_1::detail::visit_all_pack<migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const::'lambda'(auto, auto, auto)&>(migraphx::version_1::shape const&, auto&&)::'lambda'(auto&&...)&...)::'lambda'()>(migraphx::version_1::shape::type_t, auto, auto) + 96
    frame #2: 0x00007fffcfb56e47 libmigraphx.so.2`migraphx::version_1::op::binary<migraphx::version_1::op::mul>::compute(migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> >) const + 199
    frame #3: 0x00007fffcfb59cc7 libmigraphx.so.2`decltype(fp0.compute(make_compute_output_shape(pack(fp0, fp2, fp3)), fp3)) migraphx::version_1::detail::compute_op<migraphx::version_1::op::mul, std::function<std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > (migraphx::version_1::module*&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > > const&)> >(migraphx::version_1::rank<2>, migraphx::version_1::op::mul const&, migraphx::version_1::context&, migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > const&, std::vector<migraphx::version_1::module*, std::allocator<migraphx::version_1::module*> > const&, std::function<std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > (migraphx::version_1::module*&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > > const&)>) + 423
    frame #4: 0x00007fffcfb59a8c libmigraphx.so.2`migraphx::version_1::argument migraphx::version_1::detail::compute_op<migraphx::version_1::op::mul, std::function<std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > (migraphx::version_1::module*&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > > const&)> >(migraphx::version_1::op::mul const&, migraphx::version_1::context&, migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > const&, std::vector<migraphx::version_1::module*, std::allocator<migraphx::version_1::module*> > const&, std::function<std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > (migraphx::version_1::module*&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > > const&)>) + 124
    frame #5: 0x00007fffcfb55b98 libmigraphx.so.2`migraphx::version_1::operation::private_detail_te_handle_type<migraphx::version_1::op::mul>::compute(migraphx::version_1::context&, migraphx::version_1::shape const&, std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > const&, std::vector<migraphx::version_1::module*, std::allocator<migraphx::version_1::module*> > const&, std::function<std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > (migraphx::version_1::module*&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > > const&)>) const + 56
    frame #6: 0x00007fffcf5a46d6 libmigraphx.so.2`std::vector<migraphx::version_1::argument, std::allocator<migraphx::version_1::argument> > migraphx::version_1::generic_eval<auto migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const::$_1::operator()<migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const::$_2>(migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const::$_2) const::'lambda'(migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, 
std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const::$_2&&)>(migraphx::version_1::module const*, migraphx::version_1::context&, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, std::unordered_map<std::_List_iterator<migraphx::version_1::instruction>, migraphx::version_1::argument, std::hash<std::_List_iterator<migraphx::version_1::instruction> >, std::equal_to<std::_List_iterator<migraphx::version_1::instruction> >, std::allocator<std::pair<std::_List_iterator<migraphx::version_1::instruction> const, migraphx::version_1::argument> > >, migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const::$_2) + 2998
    frame #7: 0x00007fffcf5984a7 libmigraphx.so.2`migraphx::version_1::program::eval(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, migraphx::version_1::execution_environment) const + 1015
    frame #8: 0x00007fffcf59cb4d libmigraphx.so.2`migraphx::version_1::program::perf_report(std::ostream&, unsigned long, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, migraphx::version_1::argument, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, migraphx::version_1::argument> > >, unsigned long) const + 269
    frame #9: 0x000000000026b1bd driver`migraphx::driver::version_1::perf::run() + 605
    frame #10: 0x000000000026a963 driver`void migraphx::driver::version_1::run_command<migraphx::driver::version_1::perf>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, bool) + 1411
    frame #11: 0x000000000026a353 driver`int migraphx::driver::version_1::auto_register_command<migraphx::driver::version_1::perf>()::'lambda'(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >)::operator()(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 131
    frame #12: 0x000000000026a24d driver`std::_Function_handler<void (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >), int migraphx::driver::version_1::auto_register_command<migraphx::driver::version_1::perf>()::'lambda'(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >)>::_M_invoke(std::_Any_data const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >&&) + 45
    frame #13: 0x00000000002455e2 driver`main + 834
    frame #14: 0x00007fffcdcb2083 libc.so.6`__libc_start_main(main=(driver`main), argc=18, argv=0x00007fffffffe408, init=<unavailable>, fini=<unavailable>, rtld_fini=<unavailable>, stack_end=0x00007fffffffe3f8) at libc-start.c:308:16
    frame #15: 0x00000000002451de driver`_start + 46

The fill on the input doesn't seem to be doing anything. It looks like we're not filling images correctly, and the run fails on the first mul after all our passes.

I also had to turn off pointwise fusions to do this.

TedThemistokleous commented 1 year ago

After some more digging, and with help from Paul, it looks like this PR removed a number of device functions that are used when running JIT:

PR #1394

As a workaround, I've reverted those changes as part of my debugging effort, and I now get to the following point during evaluation:

Output has zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0
Run instruction: main:@708 = transpose[permutation={1, 0}](main:@695) -> float_type, {116424, 1}, {1, 116424}
Time: 0.00349ms, 0.0042ms
Output has zero, normal
Output: 58337, 58338, 58601, 58602, 59130, ..., 0, 0, 0, 0, 0
Run instruction: main:@709 = load[offset=31562784,end=32028480](main:@1) -> float_type, {116424}, {1}
Time: 0.00204ms, 0.00255ms
Run instruction: main:@710 = gpu::code_object[code_object=26496,symbol_name=gathernd_kernel,global=61440,local=1024,](main:@667,main:@708,main:@709) -> float_type, {116424}, {1}
Time: 3.02973ms, 3.67899ms
Output has normal
Output: 0.0569803, 0.0555581, 0.067194, 0.0569749, 0.0538682, ..., 0.000887294, 0.000887294, 0.000887294, 0.000887294, 0.000887294
Run instruction: main:@711 = load[offset=0,end=1397088](main:@1) -> [float_type, {116424}, {1}, int64_type, {116424}, {1}]
Time: 0.016931ms, 0.017401ms
Run instruction: main:@712 = gpu::topk[k=116424,axis=0,largest=1](main:@710,main:@711) -> [float_type, {116424}, {1}, int64_type, {116424}, {1}]
Time: 1.79287ms, 105.989ms
terminate called after throwing an instance of 'migraphx::version_1::exception'
  what():  /code/AMDMIGraphX/src/include/migraphx/raw_data.hpp:104: operator(): Invalid tuple type

This appeared to be an error with topk, but the above was taken with MIGRAPHX_TRACE_EVAL=2. Reducing the trace level yields the following output, which fails on mod instead:

Run instruction: main:@749 = gpu::topk[k=23760000,axis=0,largest=1](main:@747,main:@748) -> [float_type, {23760000}, {1}, int64_type, {23760000}, {1}]
Time: 0.01392ms, 19118.1ms
Run instruction: main:@750 = get_tuple_elem[index=1](main:@728) -> int64_type, {401544}, {1}
Time: 0.00671ms, 0.00716ms
Run instruction: main:@751 = load[offset=416914432,end=418520608](main:@1) -> float_type, {401544}, {1}
Time: 0.003661ms, 0.004071ms
Run instruction: main:@752 = gpu::convert[target_type=2](main:@750,main:@751) -> float_type, {401544}, {1}
Time: 0.02063ms, 0.045221ms
Run instruction: main:@753 = squeeze[axes={1}](main:@724) -> float_type, {401544}, {1}
Time: 0.00183ms, 0.00215ms
Run instruction: main:@754 = load[offset=407415968,end=409022144](main:@1) -> float_type, {401544}, {1}
Time: 0.0009ms, 0.00143ms
Run instruction: main:@755 = gpu::gather[axis=0](main:@753,main:@752,main:@754) -> float_type, {401544}, {1}
Time: 0.013441ms, 0.036581ms
Run instruction: main:@756 = get_tuple_elem[index=1](main:@739) -> int64_type, {1485000}, {1}
Time: 0.00179ms, 0.0021ms
Run instruction: main:@757 = load[offset=415308256,end=421248256](main:@1) -> float_type, {1485000}, {1}
Time: 0.00079ms, 0.00111ms
Run instruction: main:@758 = gpu::convert[target_type=2](main:@756,main:@757) -> float_type, {1485000}, {1}
Time: 0.00324ms, 0.055451ms
Run instruction: main:@759 = squeeze[axes={1}](main:@735) -> float_type, {1485000}, {1}
Time: 0.00069ms, 0.001ms
Run instruction: main:@760 = load[offset=504458432,end=510398432](main:@1) -> float_type, {1485000}, {1}
Time: 0.00079ms, 0.00109ms
Run instruction: main:@761 = gpu::gather[axis=0](main:@759,main:@758,main:@760) -> float_type, {1485000}, {1}
Time: 0.00397ms, 0.049601ms
Run instruction: main:@762 = get_tuple_elem[index=1](main:@749) -> int64_type, {23760000}, {1}
Time: 0.00148ms, 0.00179ms
Run instruction: main:@763 = load[offset=630769824,end=725809824](main:@1) -> float_type, {23760000}, {1}
Time: 0.00083ms, 0.00124ms
Run instruction: main:@764 = gpu::convert[target_type=2](main:@762,main:@763) -> float_type, {23760000}, {1}
Time: 0.00334ms, 0.629164ms
Run instruction: main:@765 = load[offset=409368256,end=504408256](main:@1) -> float_type, {23760000}, {1}
Time: 0.00068ms, 0.00105ms
Run instruction: main:@766 = squeeze[axes={1}](main:@745) -> float_type, {23760000}, {1}
Time: 0.00103ms, 0.00134ms
Run instruction: main:@767 = gpu::gather[axis=0](main:@766,main:@764,main:@765) -> float_type, {23760000}, {1}
Time: 0.00371ms, 0.649255ms
Run instruction: main:@768 = hip::hip_copy_literal[id=main:@literal:153] -> float_type, {256}, {1}
Time: 0.00267ms, 0.00298ms
Run instruction: main:@769 = hip::hip_copy_literal[id=main:@literal:154] -> float_type, {256, 256, 3, 3}, {2304, 9, 3, 1}
Time: 0.00111ms, 0.00142ms
Run instruction: main:@770 = hip::hip_copy_literal[id=main:@literal:155] -> float_type, {256}, {1}
Time: 0.00096ms, 0.00126ms
Run instruction: main:@771 = hip::hip_copy_literal[id=main:@literal:156] -> float_type, {256, 256, 3, 3}, {2304, 9, 3, 1}
Time: 0.0013ms, 0.00161ms
Run instruction: main:@772 = hip::hip_copy_literal[id=main:@literal:157] -> float_type, {256}, {1}
Time: 0.00098ms, 0.00129ms
Run instruction: main:@773 = hip::hip_copy_literal[id=main:@literal:158] -> float_type, {256, 256, 3, 3}, {2304, 9, 3, 1}
Time: 0.00095ms, 0.00158ms
Run instruction: main:@774 = multibroadcast[out_lens={401544}](main:@2) -> float_type, {401544}, {0}
Time: 0.001851ms, 0.002201ms
Run instruction: main:@775 = multibroadcast[out_lens={401544}](main:@723) -> float_type, {401544}, {0}
Time: 0.00085ms, 0.00116ms
Run instruction: main:@776 = mod(main:@755,main:@775) -> float_type, {401544}, {1}

It appears we're still using the ref implementation for the mod operator, and no GPU function seems to be found. I'm not sure why the higher trace level obscured this. All topk instructions appear to run in succession correctly.

resnext50_logging.txt (trace eval=1)
resnext50_debug_gpu_revert_commit.txt (trace eval=2)

It looks like I need to determine why we're still calling the ref implementation here, as I thought proper GPU ops were added via PR #1306.

What's odd is that the referenced #1394 doesn't remove anything for mod either.
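As a reference for what the JIT'd mod has to reproduce, the two behaviours of the ONNX Mod operator can be sketched on the host as below. The helper names are hypothetical and are not the functions in MIGraphX's math.hpp; the sketch only assumes the ONNX spec's rule that fmod=1 follows the sign of the dividend (C fmod), while the default integer mod follows the sign of the divisor.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical host-side helpers (not MIGraphX's math.hpp) showing the two
// behaviours the ONNX Mod operator selects between via its `fmod` attribute.

// fmod = 1: C-style remainder, the result takes the sign of the dividend.
inline double mod_fmod(double a, double b) { return std::fmod(a, b); }

// fmod = 0 (integer inputs): the result takes the sign of the divisor,
// like Python's % operator.
inline std::int64_t mod_python(std::int64_t a, std::int64_t b)
{
    std::int64_t r = a % b; // C++ % truncates toward zero
    if(r != 0 and ((r < 0) != (b < 0)))
        r += b; // shift the remainder into the divisor's sign
    return r;
}
```

For example, mod_python(-4, 3) gives 2 while mod_fmod(-4.0, 3.0) gives -1.0, which is why a single % or fmod call cannot serve both modes.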

pfultz2 commented 1 year ago

Looks like I need to determine why we're still calling the ref implementation here as I thought proper gpu ops were added via PR#1306

It was added for the GPU, but using JIT compilation, not the legacy precompiled operators.

What's odd is in the referenced https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/pull/1394 doesn't remove anything for mod either.

It was never added as a legacy device function, so there wasn't anything to remove.

TedThemistokleous commented 1 year ago

Taking a step back: it appears we need pointwise fusions to work in order to use the mod operator correctly without adding a stopgap implementation for mod and undoing the previously mentioned PR #1394.

Debugging this:

After modifying logical_xor.hpp to print debug info by overloading compute_shape() with some extra printouts, we get the following inputs, which trigger the check_shapes error:

bool_type, {1}, {0}, int8_type, {1}, {0}
terminate called after throwing an instance of 'migraphx::version_1::exception'
  what(): /code/AMDMIGraphX/src/include/migraphx/check_shapes.hpp:179: same_type: logical_xor: Types do not match

I applied a quick fix to logical_xor by overloading compute_shape(). Right now that yields compiled output, but I think I still need to do a conversion: one of the fused inputs uses int8_type and the other uses bool_type, which I assume would have different sizes on the CPU and would require padding.
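A minimal sketch of the alternative to relaxing the type check, assuming the intent is plain logical xor: normalise both operands to bool (non-zero means true) before comparing, which is what a convert to bool_type ahead of logical_xor would do. The function name is hypothetical. On common x86-64 toolchains bool and int8_t are both one byte, so the practical hazard is the values rather than padding: an int8 value of 5 is logically true but is not bitwise equal to a bool true.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch, not MIGraphX code: instead of relaxing the same_type
// check, normalise both operands to bool (non-zero == true) and then xor.
// This mirrors inserting a convert to bool_type ahead of logical_xor.
std::vector<bool> logical_xor_mixed(const std::vector<bool>& a,
                                    const std::vector<std::int8_t>& b)
{
    assert(a.size() == b.size());
    std::vector<bool> out(a.size());
    for(std::size_t i = 0; i < a.size(); ++i)
        out[i] = a[i] != (b[i] != 0); // xor on normalised booleans
    return out;
}
```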

Right now I get the following output with my naive changes to PR #1458:

Run instruction: main:@731 = gpu::code_object[code_object=13784,symbol_name=equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel,global=371250,local=1024,](main:@715,main:@683,main:@730) -> float_type, {1485000}, {1}
Time: 0.00582ms, 0.073871ms
Run instruction: main:@732 = load[offset=681068608,end=682674784](main:@1) -> float_type, {1, 401544}, {401544, 1}
Time: 0.00063ms, 0.00095ms
Run instruction: main:@733 = gpu::code_object[code_object=13632,symbol_name=convert_kernel,global=100386,local=1024,](main:@726,main:@732) -> float_type, {1, 401544}, {401544, 1}
Time: 0.00549ms, 0.02602ms
Run instruction: main:@734 = load[offset=558799648,end=653839648](main:@1) -> float_type, {23760000}, {1}
Time: 0.0008ms, 0.00112ms
Run instruction: main:@735 = gpu::code_object[code_object=13784,symbol_name=equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel,global=5940000,local=1024,](main:@719,main:@689,main:@734) -> float_type, {23760000}, {1}
Time: 0.00572ms, 0.86162ms
Run instruction: main:@736 = hip::hip_copy_literal[id=main:@literal:73] -> float_type, {5625, 4}, {4, 1}
Time: 0.00135ms, 0.00168ms
Run instruction: main:@737 = hip::hip_copy_literal[id=main:@literal:72] -> float_type, {22500, 4}, {4, 1}
Time: 0.00108ms, 0.0014ms
Run instruction: main:@738 = hip::hip_copy_literal[id=main:@literal:71] -> float_type, {1}, {0}
Time: 0.00103ms, 0.00136ms
Run instruction: main:@739 = load[offset=707299648,end=802339648](main:@1) -> float_type, {5940000, 4}, {4, 1}
Time: 0.00073ms, 0.00106ms
Run instruction: main:@740 = gpu::gather[axis=0](main:@737,main:@721,main:@739) -> float_type, {5940000, 4}, {4, 1}
Memory access fault by GPU node-1 (Agent handle: 0x113f180) on address 0x7f43847a0000. Reason: Page not present or supervisor privilege.
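One hypothesis worth ruling out for the fault above is an out-of-range index: the gather at @740 reads from a {22500, 4} literal, so every index must lie in [0, 22500). A hypothetical host reference for gather along axis 0 with an explicit bounds check (not MIGraphX's gpu::gather) could be used to validate the indices produced upstream:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical host reference (not MIGraphX's gather) for gather along axis 0
// on a row-major {rows, cols} tensor. Checking every index against `rows`
// turns what would be a GPU memory access fault into a readable error.
std::vector<float> gather_axis0(const std::vector<float>& data,
                                std::size_t rows,
                                std::size_t cols,
                                const std::vector<std::int64_t>& indices)
{
    assert(data.size() == rows * cols);
    std::vector<float> out;
    out.reserve(indices.size() * cols);
    for(std::int64_t idx : indices)
    {
        // ONNX Gather allows negative indices that count from the end.
        std::int64_t i = idx < 0 ? idx + static_cast<std::int64_t>(rows) : idx;
        if(i < 0 or i >= static_cast<std::int64_t>(rows))
            throw std::out_of_range("gather index out of range");
        for(std::size_t c = 0; c < cols; ++c)
            out.push_back(data[static_cast<std::size_t>(i) * cols + c]);
    }
    return out;
}
```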

Today's dumps.

resnext50_logging.txt (without compiles)

resnext50_logging_compiles.txt

TedThemistokleous commented 1 year ago

Remembering what @umangyadav mentioned about reshapes, I took a look at the simplify_reshapes() section of our passes and found this matcher, which got me curious:

struct find_nop_reshapes
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
        reshapes.insert("as_shape");
        reshapes.insert("broadcast");
        reshapes.insert("concat");
        reshapes.insert("convert");
        reshapes.insert("multibroadcast");
        reshapes.insert("pad");
        reshapes.insert("slice");
        reshapes.insert("transpose");
        return match::name(reshapes)(match::same_shape(match::arg(0)));
    }

    void apply(module& m, const match::matcher_result& mr) const
    {
        auto ins = mr.result;
        m.replace_instruction(ins, ins->inputs().front());
    }
};

It appears that most sections of the network use logical_xor almost always preceded by a convert:

main:@2022 = convert[target_type=0](main:@2021) -> bool_type, {23760000}, {1}
main:@2023 = less(main:@1996,main:@377) -> int64_type, {1}, {0}
main:@2024 = convert[target_type=0](main:@2023) -> bool_type, {1}, {0}
main:@2025 = multibroadcast[out_lens={23760000}](main:@2024) -> bool_type, {23760000}, {0}
main:@2026 = logical_xor(main:@2022,main:@2025) -> bool_type, {23760000}, {1}

This makes me wonder whether, if one of the inputs were a binary zero after conversion, this matcher would remove the convert and use the input from the previous op instead, since we're checking against the argument to that convert.
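Since matchers inspect instruction shapes at compile time rather than runtime values, whether this matcher can drop the convert depends on whether same_shape compares element types, not on whether the data happens to be zero. A hypothetical miniature of that check follows, assuming shape equality includes the type (worth confirming against migraphx::shape's operator==); under that assumption the int64_type -> bool_type convert at @2024 is never treated as a no-op.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical miniature of a tensor shape (not migraphx::shape), used only to
// illustrate when a "nop reshape/convert" removal would be safe.
enum class dtype { bool_type, int8_type, int64_type, float_type };

struct mini_shape
{
    dtype type;
    std::vector<std::size_t> lens;
    std::vector<std::size_t> strides;
    // Equality includes the element type, not just lens and strides.
    bool operator==(const mini_shape& o) const
    {
        return type == o.type and lens == o.lens and strides == o.strides;
    }
};

// A convert is only a true no-op when its input and output shapes match exactly.
bool is_nop(const mini_shape& input, const mini_shape& output) { return input == output; }
```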

TedThemistokleous commented 1 year ago

resnext50_friday_debug2.txt

I had to address an issue with gpu_from, with Paul's help.

It looks like we're still having an issue with gather. Posting the log here to take a look and keep for reference.

Currently running this with:

bin/driver perf ../resnext50_32x4d_fpn.onnx --fill1 images --input-dim @images 1 3 800 800 --output-names @boxes @labels @scores --enable-offload-copy -n 1

and the error I'm seeing is:

Run instruction: main:@729 = gpu::gather[axis=0](main:@727,main:@726,main:@728) -> float_type, {116424}, {1}
Time: 0.01203ms, 0.080641ms
Output has zero, normal
Output: 92658, 58601, 59393, 74970, 75233, ..., 0, 0, 0, 0, 0
Run instruction: main:@730 = load[offset=534573952,end=540513952](main:@1) -> float_type, {1485000}, {1}
Time: 0.00094ms, 0.00116ms
Run instruction: main:@731 = gpu::code_object[code_object=13784,symbol_name=equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel,global=371250,local=1024,](main:@707,main:@681,main:@730) -> float_type, {1485000}, {1}
Time: 0.005931ms, 0.135103ms
Output has zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0
Run instruction: main:@732 = get_tuple_elem[index=1](main:@723) -> int64_type, {401544}, {1}
Time: 0.0027ms, 0.00297ms
Output has zero, normal
Output: 1, 2, 3, 4, 5, ..., 401540, 401541, 401542, 401543, 0
Run instruction: main:@733 = load[offset=124740000,end=126346176](main:@1) -> float_type, {401544}, {1}
Time: 0.00203ms, 0.00224ms
Run instruction: main:@734 = gpu::code_object[code_object=13632,symbol_name=convert_kernel,global=100386,local=1024,](main:@732,main:@733) -> float_type, {401544}, {1}
Time: 0.01038ms, 0.086501ms
Output has zero, normal
Output: 1, 2, 3, 4, 5, ..., 401540, 401541, 401542, 401543, 0
Run instruction: main:@735 = squeeze[axes={1}](main:@719) -> float_type, {401544}, {1}
Time: 0.0015ms, 0.00171ms
Output has zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0
Run instruction: main:@736 = load[offset=841448608,end=843054784](main:@1) -> float_type, {401544}, {1}
Time: 0.00124ms, 0.00145ms
Run instruction: main:@737 = gpu::gather[axis=0](main:@735,main:@734,main:@736) -> float_type, {401544}, {1}
Time: 0.00977ms, 0.099672ms
Output has zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0
Run instruction: main:@738 = hip::hip_copy_literal[id=main:@literal:15] -> float_type, {5625, 4}, {4, 1}
Time: 0.00336ms, 0.00359ms
Output has zero, normal
Output: -91, -45, 91, 45, -114, ..., 882, 696, 624, 840, 912
Run instruction: main:@739 = load[offset=564739648,end=659779648](main:@1) -> float_type, {23760000}, {1}
Time: 0.00112ms, 0.00134ms
Run instruction: main:@740 = gpu::code_object[code_object=13784,symbol_name=equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel,global=5940000,local=1024,](main:@709,main:@687,main:@739) -> float_type, {23760000}, {1}
Time: 0.007681ms, 1.06024ms
Output has normal, zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 3.34356e-06, 0
Run instruction: main:@741 = load[offset=124740000,end=125205696](main:@1) -> float_type, {116424}, {1}
Time: 0.004ms, 0.00428ms
Run instruction: main:@742 = gpu::code_object[code_object=13624,symbol_name=mod_kernel,global=58212,local=1024,](main:@729,main:@741) -> float_type, {116424}, {1}
Time: 0.023271ms, 1.45881ms
Output has zero, normal
Output: 258, 257, 257, 258, 257, ..., 0, 0, 0, 0, 0
Run instruction: main:@743 = load[offset=540513952,end=564273952](main:@1) -> float_type, {5940000}, {1}
Time: 0.00102ms, 0.00125ms
Run instruction: main:@744 = gpu::code_object[code_object=13784,symbol_name=equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel,global=1485000,local=1024,](main:@713,main:@703,main:@743) -> float_type, {5940000}, {1}
Time: 0.00675ms, 0.317406ms
Output has zero
Output: 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0
Run instruction: main:@745 = hip::hip_copy_literal[id=main:@literal:70] -> float_type, {22500, 4}, {4, 1}
Time: 0.00507ms, 0.00535ms
Output has zero, normal
Output: -45, -23, 45, 23, -57, ..., 841, 748, 713, 820, 855
Run instruction: main:@746 = load[offset=815825824,end=839585824](main:@1) -> float_type, {1485000, 4}, {4, 1}
Time: 0.00289ms, 0.00311ms
Run instruction: main:@747 = gpu::gather[axis=0](main:@738,main:@731,main:@746) -> float_type, {1485000, 4}, {4, 1}
Time: 0.03306ms, 0.434697ms
Output has normal
Output: -91, -45, 91, 45, -91, ..., 45, -91, -45, 91, 45
Run instruction: main:@748 = load[offset=720785824,end=815825824](main:@1) -> float_type, {5940000, 4}, {4, 1}
Time: 0.00406ms, 0.00435ms
Run instruction: main:@749 = gpu::gather[axis=0](main:@745,main:@744,main:@748) -> float_type, {5940000, 4}, {4, 1}
Memory access fault by GPU node-1 (Agent handle: 0x1bd7220) on address 0x7fc4d7da0000. Reason: Page not present or supervisor privilege.
TedThemistokleous commented 1 year ago

Few updates on this

- Found that the issue during eval was related to fast_div in the gpu::gather calls. We decided it would be better to implement a JIT-based gather, which is done in PR #1492.

It looks like I'll need to move nonmaxsuppression to the GPU somehow.
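For context on how a fast_div-style trick can misbehave, here is a generic multiply-shift division sketch. It is not MIGraphX's fast_div (whose encoding may differ); it only illustrates that such tricks are exact over a bounded range of dividends and can silently be off by one outside it, which inside a gather would yield a wrong index rather than an error. The struct name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Generic multiply-shift division sketch (NOT MIGraphX's fast_div): precompute
// m = ceil(2^32 / d) once for a divisor d > 0, then n / d ~= (n * m) >> 32.
// The result is exact whenever n * d < 2^32; beyond that it can be off by one.
struct magic_div
{
    std::uint64_t m;
    explicit magic_div(std::uint32_t d) : m(((std::uint64_t{1} << 32) + d - 1) / d) {}
    std::uint32_t operator()(std::uint32_t n) const
    {
        return static_cast<std::uint32_t>((n * m) >> 32);
    }
};
```

With d = 7, for instance, n = 1431655770 lies outside the safe range and the sketch returns 204522253, one more than the true quotient 204522252.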

TedThemistokleous commented 1 year ago

An update on this using Linux perf; this is the current perf top output:

[Image: perf top output]

It seems like running this with the --gpu flag leaves the system just spinning on batch_box, in compute_nms. I'm not sure why that would be the case here. I'll leave this running overnight and see in the morning whether anything else shows up as running.

TedThemistokleous commented 1 year ago

So an update on this:

The GPU side still runs forever. Perf also seems to indicate that we're sort of "spinning" when running this on the GPU. Originally I thought this was related to the limits and the while() loop in the compute_nms call, so I added some probing:

    template <class Output, class Boxes, class Scores>
    std::size_t compute_nms(Output output,
                            Boxes boxes,
                            Scores scores,
                            const shape& max_output_shape,
                            std::size_t max_output_boxes_per_class,
                            double iou_threshold,
                            double score_threshold) const
    {
        std::fill(output.begin(), output.end(), 0);
        const auto& lens       = scores.get_shape().lens();
        const auto num_batches = lens[0];
        const auto num_classes = lens[1];
        const auto num_boxes   = lens[2];
        // boxes of a class with NMS applied [score, index]
        std::vector<std::pair<double, int64_t>> selected_boxes_inside_class;
        std::vector<int64_t> selected_indices;
        selected_boxes_inside_class.reserve(max_output_shape.elements());
        // iterate over batches and classes
        shape comp_s{shape::double_type, {num_batches, num_classes}};
        shape_for_each(comp_s, [&](auto idx) {
            auto batch_idx = idx[0];
            auto class_idx = idx[1];
            // index offset for this class
            auto scores_start = scores.begin() + (batch_idx * num_classes + class_idx) * num_boxes;
            // iterator to first value of this batch
            auto batch_boxes_start = boxes.begin() + batch_idx * num_boxes * 4;
            auto boxes_heap = filter_boxes_by_score(scores_start, num_boxes, score_threshold);
            selected_boxes_inside_class.clear();
            // Get the next box with top score, filter by iou_threshold

            while(not boxes_heap.empty() &&
                  selected_boxes_inside_class.size() < max_output_boxes_per_class)
            {
                std::cout << "heap size=" << boxes_heap.size() << std::endl;

                // Check with existing selected boxes for this class, remove box if it
                // exceeds the IOU (Intersection Over Union) threshold
                const auto next_top_score = boxes_heap.top();
                bool not_selected =
                    std::any_of(selected_boxes_inside_class.begin(),
                                selected_boxes_inside_class.end(),
                                [&](auto selected_index) {
                                    return this->suppress_by_iou(
                                        batch_box(batch_boxes_start, next_top_score.second),
                                        batch_box(batch_boxes_start, selected_index.second),
                                        iou_threshold);
                                });

                if(not not_selected)
                {
                    selected_boxes_inside_class.push_back(next_top_score);
                    selected_indices.push_back(batch_idx);
                    selected_indices.push_back(class_idx);
                    selected_indices.push_back(next_top_score.second);
                    std::cout << "Not selected" << std::endl;
                }

                boxes_heap.pop();
                std::cout << "Pop heap size=" << boxes_heap.size() << std::endl;
            }
        });
        // (excerpt: the std::size_t return at the end of compute_nms is not shown)
    }

What's interesting is max_output_boxes_per_class: I would expect it to be fairly low so that this loop exits, but I'm seeing a HUGE value for this threshold (9223372036854775808, i.e. 2^63, with only the top bit of the 64-bit size_t set), which is passed in via arg(), so that condition in the loop effectively does nothing.
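To double-check the arithmetic: 9223372036854775808 is exactly 2^63. The sketch below confirms this at compile time and shows one way such a value can arise (the most negative int64 reinterpreted as unsigned); whether that is what happens here is an assumption. The helper name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// 9223372036854775808 is 2^63: only the most significant bit of a 64-bit value set.
constexpr std::uint64_t observed_limit = std::uint64_t{1} << 63;
static_assert(observed_limit == 9223372036854775808ULL, "2^63");

// One way such a value can arise (an assumption, not confirmed for this model):
// the most negative int64 reinterpreted as an unsigned 64-bit integer.
static_assert(static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::min()) ==
                  observed_limit,
              "INT64_MIN as uint64 is 2^63");

// Hypothetical helper: the `selected.size() < max_output_boxes_per_class` check can
// only end the loop before the heap empties when the limit is below the number of
// candidate boxes; otherwise the loop always runs until the heap is empty.
constexpr bool limit_can_end_loop_early(std::uint64_t max_output_boxes_per_class,
                                        std::uint64_t num_candidate_boxes)
{
    return max_output_boxes_per_class < num_candidate_boxes;
}
```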

On the CPU side, I can see the heap drain steadily to zero, but on the GPU run, for some reason, we hit the "Not selected" print for every single top box before it is popped off.

heap size=31702968
Not selected
Pop heap size=31702967
heap size=31702967
Not selected
Pop heap size=31702966
heap size=31702966
Not selected
Pop heap size=31702965
heap size=31702965
Not selected
Pop heap size=31702964
heap size=31702964
Not selected
Pop heap size=31702963
heap size=31702963
Not selected
Pop heap size=31702962
heap size=31702962
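Note that in the probe above, "Not selected" is printed inside if(not not_selected), i.e. when the box is kept. If, as that output suggests, every popped box is kept, the any_of scan runs over an ever-growing selected list, so the loop performs roughly n^2/2 IoU checks; for about 31.7 million candidates that is on the order of 5 x 10^14. The following hypothetical host sketch of greedy NMS (not MIGraphX's implementation; boxes are assumed to be ordered {y1, x1, y2, x2} corners) counts IoU evaluations to make that growth visible:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical host sketch of greedy NMS for one (batch, class) pair, not
// MIGraphX's nonmaxsuppression. Boxes are ordered {y1, x1, y2, x2} corners.
struct box { float y1, x1, y2, x2; };

inline float iou(const box& a, const box& b)
{
    float ih    = std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
    float iw    = std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
    float inter = ih * iw;
    float uni   = (a.y2 - a.y1) * (a.x2 - a.x1) + (b.y2 - b.y1) * (b.x2 - b.x1) - inter;
    return uni > 0.0f ? inter / uni : 0.0f;
}

// Returns kept indices; `comparisons` receives the number of IoU evaluations.
std::vector<std::size_t> greedy_nms(const std::vector<box>& boxes,
                                    const std::vector<float>& scores,
                                    float iou_threshold,
                                    std::size_t& comparisons)
{
    // max-heap of (score, index), highest score on top
    std::priority_queue<std::pair<float, std::size_t>> heap;
    for(std::size_t i = 0; i < scores.size(); ++i)
        heap.emplace(scores[i], i);
    std::vector<std::size_t> selected;
    comparisons = 0;
    while(not heap.empty())
    {
        auto top = heap.top();
        heap.pop(); // every iteration pops, so the loop always terminates
        bool suppressed = std::any_of(selected.begin(), selected.end(), [&](std::size_t s) {
            ++comparisons;
            return iou(boxes[top.second], boxes[s]) > iou_threshold;
        });
        if(not suppressed)
            selected.push_back(top.second); // kept: later boxes are also checked against it
    }
    return selected;
}
```

With 100 non-overlapping boxes nothing is suppressed and the sketch performs 0 + 1 + ... + 99 = 4950 IoU checks, which is the quadratic pattern described above.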

Interestingly, removing the center logic

                // Check with existing selected boxes for this class, remove box if it
                // exceeds the IOU (Intersection Over Union) threshold
                const auto next_top_score = boxes_heap.top();
                bool not_selected =
                    std::any_of(selected_boxes_inside_class.begin(),
                                selected_boxes_inside_class.end(),
                                [&](auto selected_index) {
                                    return this->suppress_by_iou(
                                        batch_box(batch_boxes_start, next_top_score.second),
                                        batch_box(batch_boxes_start, selected_index.second),
                                        iou_threshold);
                                });

                if(not not_selected)
                {
                    selected_boxes_inside_class.push_back(next_top_score);
                    selected_indices.push_back(batch_idx);
                    selected_indices.push_back(class_idx);
                    selected_indices.push_back(next_top_score.second);
                    std::cout << "Not selected" << std::endl;
                }

on the GPU run actually got me through the network, albeit with results I believe aren't correct. I got the following timing output for a single run:

Summary:
nonmaxsuppression: 69473.1ms / 1 = 69473.1ms, 74%
gpu::topk: 24443.2ms / 5 = 4888.64ms, 26%
hip::copy_from_gpu: 163.709ms / 4 = 40.9272ms, 1%
gpu::code_object::reduce_kernel: 141.387ms / 1 = 141.387ms, 1%
gpu::nonzero: 129.067ms / 5 = 25.8135ms, 1%
hip::copy_to_gpu: 60.5253ms / 1 = 60.5253ms, 1%
gpu::miopen_fusion: 23.7921ms / 49 = 0.485552ms, 1%
gpu::convolution: 23.1177ms / 53 = 0.436183ms, 1%
gpu::code_object::gather_kernel: 15.8803ms / 50 = 0.317606ms, 1%
gpu::code_object::concat_kernel: 10.9319ms / 20 = 0.546597ms, 1%
gpu::code_object::mul_add_kernel: 6.36609ms / 21 = 0.303147ms, 1%
gpu::code_object::convert_kernel: 5.04451ms / 12 = 0.420376ms, 1%
gpu::code_object::max_min_kernel: 4.78901ms / 10 = 0.478901ms, 1%
gpu::code_object::sub_kernel: 3.47368ms / 20 = 0.173684ms, 1%
gpu::code_object::min_exp_mul_mul_kernel: 3.38428ms / 10 = 0.338428ms, 1%
gpu::code_object::add_relu_kernel: 3.02912ms / 37 = 0.0818682ms, 1%
gpu::code_object::concat_add_kernel: 2.90338ms / 1 = 2.90338ms, 1%
gpu::code_object::add_kernel: 2.11672ms / 17 = 0.124513ms, 1%
gpu::code_object::contiguous_kernel: 1.56876ms / 14 = 0.112055ms, 1%
gpu::code_object::add_add_relu_kernel: 1.33128ms / 12 = 0.11094ms, 1%
gpu::code_object::equal_convert_convert_not_less_convert_convert_logical_xor_mul_add_logical_and_where_kernel: 1.07019ms / 5 = 0.214039ms, 1%
gpu::code_object::gathernd_kernel: 0.899728ms / 5 = 0.179946ms, 1%
gpu::code_object::mod_kernel: 0.658113ms / 5 = 0.131623ms, 1%
gpu::code_object::sigmoid_kernel: 0.655344ms / 5 = 0.131069ms, 1%
gpu::code_object::concat_mul_kernel: 0.50621ms / 1 = 0.50621ms, 1%
load: 0.489663ms / 420 = 0.00116586ms, 1%
gpu::code_object::greater_convert_kernel: 0.484339ms / 5 = 0.0968678ms, 1%
gpu::pooling: 0.253805ms / 1 = 0.253805ms, 1%
hip::hip_copy_literal: 0.240747ms / 154 = 0.00156329ms, 1%
slice: 0.087821ms / 46 = 0.00190915ms, 1%
multibroadcast: 0.080563ms / 56 = 0.00143863ms, 1%
gpu::code_object::mul_kernel: 0.080271ms / 4 = 0.0200678ms, 1%
broadcast: 0.068061ms / 53 = 0.00128417ms, 1%
unsqueeze: 0.062061ms / 60 = 0.00103435ms, 1%
step: 0.048481ms / 33 = 0.00146912ms, 1%
reshape: 0.043221ms / 36 = 0.00120058ms, 1%
get_tuple_elem: 0.033951ms / 10 = 0.0033951ms, 1%
gpu::code_object::relu_kernel: 0.023551ms / 1 = 0.023551ms, 1%
squeeze: 0.01998ms / 12 = 0.001665ms, 1%
transpose: 0.019701ms / 15 = 0.0013134ms, 1%
flatten: 0.00711ms / 5 = 0.001422ms, 1%
@param: 0.00615ms / 4 = 0.0015375ms, 1%
hip::sync_stream: 0.005561ms / 1 = 0.005561ms, 1%
check_context::migraphx::version_1::gpu::context: 0.00423ms / 1 = 0.00423ms, 1%
hip::hip_allocate_memory: 0.00328ms / 1 = 0.00328ms, 1%

Batch size: 1
Rate: 0.00986677/sec
Total time: 101350ms
Total instructions time: 94524.5ms
Overhead time: 0.758245ms, 6825.73ms
Overhead: 0%, 7%

I know the output data isn't correct per se, since I've modified the NMS operator, but seeing topK be such a large contributor to the run time is also concerning. From what I've found, it looks like constantly searching, not picking, and then pushing back top scores is eating a lot of bandwidth and is the cause of the bottleneck.
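For context on the topK cost: on the CPU, the k highest scores can be selected without sorting all n candidates. A minimal reference sketch using `std::partial_sort`, which is O(n log k); `topk_indices` is a hypothetical helper, and this says nothing about how `gpu::topk` is actually implemented.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Return the indices of the k largest scores, highest first.
// std::partial_sort orders only the first k positions: O(n log k) instead of O(n log n).
std::vector<std::size_t> topk_indices(const std::vector<float>& scores, std::size_t k)
{
    std::vector<std::size_t> idx(scores.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    k = std::min(k, idx.size());
    std::partial_sort(idx.begin(),
                      idx.begin() + static_cast<std::ptrdiff_t>(k),
                      idx.end(),
                      [&](std::size_t a, std::size_t b) {
                          // Break ties by lower index so the result is deterministic.
                          return scores[a] > scores[b] or (scores[a] == scores[b] and a < b);
                      });
    idx.resize(k);
    return idx;
}
```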

What has me curious is WHY things are different between CPU and GPU here. I'll print out scores and compare CPU/GPU to see what the results are for this. I'll also use fill0 on the inputs to make sure we have the same data between runs.
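Once both targets run on identical inputs (for example all zeros via fill0), the outputs can be compared elementwise with a tolerance, since CPU and GPU floating-point results rarely match bit for bit. A minimal sketch, assuming a normalized RMS-error metric and a hypothetical 1e-3 tolerance; `relative_rms_error` and `outputs_match` are illustrative names, and this is not necessarily the metric `migraphx-driver verify` uses.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Root-mean-square error of `actual` against `expected`, normalized by the largest |expected|.
double relative_rms_error(const std::vector<double>& expected, const std::vector<double>& actual)
{
    if(expected.size() != actual.size())
        return std::numeric_limits<double>::infinity(); // shape mismatch never matches
    if(expected.empty())
        return 0.0;
    double sum_sq  = 0.0;
    double max_abs = 0.0;
    for(std::size_t i = 0; i < expected.size(); ++i)
    {
        double d = actual[i] - expected[i];
        sum_sq += d * d;
        max_abs = std::max(max_abs, std::abs(expected[i]));
    }
    double rms = std::sqrt(sum_sq / static_cast<double>(expected.size()));
    return max_abs > 0.0 ? rms / max_abs : rms;
}

// Hypothetical pass/fail check between a CPU (ref) result and a GPU result.
bool outputs_match(const std::vector<double>& cpu,
                   const std::vector<double>& gpu,
                   double tolerance = 1e-3)
{
    return relative_rms_error(cpu, gpu) <= tolerance;
}
```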

TedThemistokleous commented 1 year ago

I'm now at the point of using trims to debug why non-max suppression stalls so much. I had to modify some older parse_if code I had, which lets us clean up the submodules generated by the If op for each branch.

Next I'll be doing trims, reductions, and verify to determine where the failure happens in the GPU implementation of the resnext50 model.
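When verify only reports a mismatch at the output, the first bad instruction can be narrowed down by bisecting over cut points instead of scanning them one by one. A minimal sketch, assuming a hypothetical `fails(n)` callback that runs verify on the program cut after n instructions (however that cut is expressed on the driver command line), and assuming failures are monotonic in n, which may not hold if an error is masked further downstream.

```cpp
#include <cstddef>
#include <functional>
#include <optional>

// Smallest n in [lo, hi] with fails(n) == true, found by binary search.
// Assumes: once the kept prefix contains the bad instruction, every longer prefix also fails.
// Returns std::nullopt if even `hi` passes.
std::optional<std::size_t> first_failing_trim(std::size_t lo,
                                               std::size_t hi,
                                               const std::function<bool(std::size_t)>& fails)
{
    if(lo > hi or not fails(hi))
        return std::nullopt;
    while(lo < hi)
    {
        std::size_t mid = lo + (hi - lo) / 2;
        if(fails(mid))
            hi = mid; // answer is at mid or earlier
        else
            lo = mid + 1; // answer is strictly after mid
    }
    return lo;
}
```

This needs about log2(N) verify runs instead of N, at the cost of relying on the monotonicity assumption.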

TedThemistokleous commented 1 year ago

Further debugging using verify and trims raised more questions, so I began running a per-instruction verify and ran into a segfault with a gather.

I'm currently debugging this as part of the issue with verify, which seems to fail on a gather with two literals as inputs.

I've created another test case that captures the current bug; the patch for this case will be checked against it:

https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/blob/bugfix_gather_jit_double_literal_input/test/verify/test_gather_literal_inputs.cpp

This seems to occur after our optimization passes.

With the help of Umang, I turned off pointwise fusions via MIGRAPHX_DISABLE_POINTWISE_FUSIONS=1 to get this log readout with lldb-12. I'm still making sense of things.

debug_gather_kernel_literal_ins.txt

TedThemistokleous commented 1 year ago

I'm currently able to get past the failing gather, as well as the null pointwise issues described in #1622 and #1587.

Running inference still seems to break NMS and stall.

Running a driver verify with the changes from the PRs mentioned above, I get broken output for slice:

In file included from main.cpp:2:
./migraphx/kernels/index.hpp:95:9: error: static assertion failed due to requirement '0 > 0': Global size must be greater than 0
        static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
        ^             ~~~~~~~~~~~~~~~~~~~~
./migraphx/kernels/index.hpp:136:28: error: division by zero is undefined [-Werror,-Wdivision-by-zero]
        return (n - _c<1>) / stride + _c<1>;
                           ^ ~~~~~~
./migraphx/kernels/index.hpp:193:26: note: in instantiation of function template specialization 'migraphx::index::max_stride_iterations<migraphx::integral_constant<unsigned int, 0>, migraphx::integral_constant<unsigned int, 0>>' requested here
            if constexpr(max_stride_iterations(n, stride) == 1)
                         ^
./migraphx/kernels/index.hpp:226:9: note: in instantiation of function template specialization 'migraphx::index::for_stride<false, (lambda at ./migraphx/kernels/pointwise.hpp:40:23), migraphx::integral_constant<unsigned int, 0>, migraphx::integral_constant<unsigned int, 0>>' requested here
        for_stride<false>(global, n, nglobal(), f);
        ^
./migraphx/kernels/pointwise.hpp:39:9: note: in instantiation of function template specialization 'migraphx::index::global_stride<(lambda at ./migraphx/kernels/pointwise.hpp:40:23), migraphx::integral_constant<unsigned int, 0>>' requested here
    idx.global_stride(out.get_shape().elements(),
        ^
./migraphx/kernels/pointwise.hpp:48:36: note: in instantiation of function template specialization 'migraphx::pointwise_tensor<(lambda at main.cpp:23:39), migraphx::tensor_view<long __attribute__((ext_vector_type(2))), migraphx::shape<migraphx::integral_const_array<unsigned int, 0>, migraphx::integral_const_array<unsigned int, 1>>>, migraphx::tensor_view<float __attribute__((ext_vector_type(2))), migraphx::shape<migraphx::integral_const_array<unsigned int, 0>, migraphx::integral_const_array<unsigned int, 1>>>>' requested here
        t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
                                   ^
In file included from main.cpp:2:
./migraphx/kernels/index.hpp:193:26: error: constexpr if condition is not a constant expression
            if constexpr(max_stride_iterations(n, stride) == 1)
                         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
./migraphx/kernels/index.hpp:226:9: note: in instantiation of function template specialization 'migraphx::index::for_stride<false, (lambda at ./migraphx/kernels/pointwise.hpp:40:23), migraphx::integral_constant<unsigned int, 0>, migraphx::integral_constant<unsigned int, 0>>' requested here
        for_stride<false>(global, n, nglobal(), f);
        ^
./migraphx/kernels/pointwise.hpp:39:9: note: in instantiation of function template specialization 'migraphx::index::global_stride<(lambda at ./migraphx/kernels/pointwise.hpp:40:23), migraphx::integral_constant<unsigned int, 0>>' requested here
    idx.global_stride(out.get_shape().elements(),
        ^
./migraphx/kernels/pointwise.hpp:48:36: note: in instantiation of function template specialization 'migraphx::pointwise_tensor<(lambda at main.cpp:23:39), migraphx::tensor_view<long __attribute__((ext_vector_type(2))), migraphx::shape<migraphx::integral_const_array<unsigned int, 0>, migraphx::integral_const_array<unsigned int, 1>>>, migraphx::tensor_view<float __attribute__((ext_vector_type(2))), migraphx::shape<migraphx::integral_const_array<unsigned int, 0>, migraphx::integral_const_array<unsigned int, 1>>>>' requested here
        t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
                                   ^
./migraphx/kernels/index.hpp:136:28: note: division by zero
        return (n - _c<1>) / stride + _c<1>;
                           ^
./migraphx/kernels/index.hpp:193:26: note: in call to 'max_stride_iterations({}, {})'
            if constexpr(max_stride_iterations(n, stride) == 1)
                         ^
3 errors generated when compiling for gfx1030.
Instruction slice threw an exception.
terminate called after throwing an instance of 'migraphx::version_1::exception'
  what():  /code/AMDMIGraphX/src/compile_src.cpp:71: compile: Output file missing: main.o

Even on a plain read off develop, without these changes, there appears to be an invalid shape used for slice after every sigmoid operation in the network. This was taken from a driver read:

main:@2008 = gather[axis=0](main:@403,main:@402) -> int64_type, {1}, {0}
main:@2009 = sigmoid(main:@2003) -> float_type, {90000, 264}, {264, 1}
main:@2010 = slice[axes={0},starts={0},ends={0}](main:@401) -> int64_type, {0}, {1}
main:@2011 = concat[axis=0](main:@2010,main:@397) -> int64_type, {1}, {1}
main:@2012 = reshape[dims={-1}](main:@2009) -> float_type, {23760000}, {1}
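The `slice` above has `starts={0}` and `ends={0}`, so it produces an empty `{0}` tensor. That appears consistent with the compiler errors earlier: the failing kernel's tensors have zero elements, its global size is 0 (the `'0 > 0'` static assertion), and `max_stride_iterations` evaluates `(n - _c<1>) / stride + _c<1>` with `stride == 0`. A minimal sketch of both computations, assuming ONNX Slice semantics with step 1; `slice_length` and `guarded_stride_iterations` are hypothetical helpers, not MIGraphX functions.

```cpp
#include <algorithm>
#include <cstdint>

// Output length of an ONNX-style Slice with step 1 on an axis of size `dim`.
// Negative starts/ends count from the end; both are then clamped to [0, dim].
std::int64_t slice_length(std::int64_t dim, std::int64_t start, std::int64_t end)
{
    if(start < 0)
        start += dim;
    if(end < 0)
        end += dim;
    start = std::clamp<std::int64_t>(start, 0, dim);
    end   = std::clamp<std::int64_t>(end, 0, dim);
    return std::max<std::int64_t>(0, end - start); // starts == ends gives an empty slice
}

// Grid-stride iteration count in the same form as the log, (n - 1) / stride + 1,
// but returning 0 for an empty launch instead of dividing by a zero stride.
std::uint32_t guarded_stride_iterations(std::uint32_t n, std::uint32_t stride)
{
    if(n == 0 or stride == 0)
        return 0;
    return (n - 1) / stride + 1;
}
```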

As a sanity check, looking at the model in Netron shows the following:

[Image: Netron screenshot]

What's very interesting here is that when I look at the compiled output with --gpu, that invalid shape seems to disappear entirely:

main:@651 = gpu::code_object[code_object=13632,symbol_name=sigmoid_kernel,global=5940000,local=1024,](main:@649,main:@650) -> float_type, {90000, 264}, {264, 1}
main:@652 = load[offset=80014176,end=85954176](main:@1) -> float_type, {1, 1485000}, {1485000, 1}

It looks like during compilation we go from the invalid slice (start == end) to a reshape instead, and the slice is removed entirely:

main:@1227 = sigmoid(main:@1221) -> float_type, {90000, 264}, {264, 1}
main:@1228 = reshape[dims={-1}](main:@1227) -> float_type, {23760000}, {1}
main:@1229 = multibroadcast[out_lens={23760000},out_dyn_dims={}](main:@118) -> float_type, {23760000}, {0}
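For reference, the `reshape[dims={-1}]` here flattens the `{90000, 264}` sigmoid output, and the `-1` is resolved from the element count: 90000 * 264 = 23760000, which matches the `{23760000}` output shape. A minimal sketch of that inference, assuming ONNX-style semantics where at most one dimension is -1 (the special meaning of 0 in ONNX Reshape is ignored here); `infer_reshape_dims` is a hypothetical helper.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Resolve a single -1 entry in reshape dims from the input element count.
// Returns std::nullopt for more than one -1 or sizes that do not divide evenly.
std::optional<std::vector<std::int64_t>> infer_reshape_dims(std::int64_t elements,
                                                            std::vector<std::int64_t> dims)
{
    std::int64_t known = 1;              // product of the explicit dimensions
    std::optional<std::size_t> inferred; // position of the -1 entry, if any
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        if(dims[i] == -1)
        {
            if(inferred)
                return std::nullopt; // at most one dimension may be inferred
            inferred = i;
        }
        else
        {
            known *= dims[i];
        }
    }
    if(not inferred)
    {
        if(known != elements)
            return std::nullopt;
        return dims;
    }
    if(known == 0 or elements % known != 0)
        return std::nullopt;
    dims[*inferred] = elements / known;
    return dims;
}
```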

It looks like it creates a pointwise module, but there's also an odd point in the compile log output where the pointwise logging gets garbled:

module: "main:pointwise428"
x0 = @param:x0 -> float_type, {1}, {0}
main:pointwise428:@1 = convert[target_type=9](x0) -> int64_type, {1}, {0}
main:pointwise428:@2 = @return(main:pointwise428:@1)

x2x1 = @param:x1 -> float_type, { = @param:x2 -> float_type, {x2 = @param:x2 -> float_type, {1}, {0}
x0 = @param:x0 -> float_type, {1}, {0}
1}, {0}
main:pointwise8:@2 = x1 = @param:x1 -> x1 = @paramfloat_type:, x1 -> float_type, {1}, {10}, }x1{ = @param:x1 -> 0}
x1x1x1{add = @param:x1 -> 1}, {x0 = @param:x0 -> float_type0, }{
 = x0@param1float_type},  = , {:{ = @param
@param:x1 -> x1float_type, {:x1 =  -> @paramx00float_type:} -> float_type, {, x1
1}, {{float_type -> float_type, {0, {(1x1 = @param:x1 -> float_type, {x01}, { = @param:x00 -> }x2float_type, {
 = @param1}}, x01
{ = main:pointwise21:@3x0}, @param1 = add0(},:}, {:0}}, 

x2{x110}
x0 = @param:x0 -> float_type, {1}, {0}
x0 = @param:x0 -> float_type, {1}, {0}
main:pointwise96:@2 = add(x0,x1) -> float_type, {1}, {0}
main:pointwise96:@3 = x0relu(main:pointwise96:@2) -> float_type, { = @param:x01 -> )float_type},  -> {,  -> 0}
x0x0}, x0 = @param:x0 -> float_type, {1}, {0}
main:pointwise57:@2 = add(x0,x1) -> float_type, {1}, {0}
main:pointwise57:@3 = relu(main:pointwise57:@2) -> float_type, {1}, {0}
1main:pointwise57:@4 = @return(main:pointwise57:@3)

{x1 = @param:x1 -> float_type, {1}, { -> 0}
float_typex0 = @param, x2float_type{ = ,@param:, 1}, {{x20main:pointwise32:@2}, float_type{} ->  = 1main:pointwise96:@4
{0x1 = 1}, {0}
x1add}
x0 = @param:x0 -> float_type, {1}, {0}
main:pointwise85:@2 = add(x0,x1) -> main:pointwise0:@3float_type, {1}, { = mul(0x0},x1)}, main:pointwise248:@0float_type
{ = @literal{x0 = @param:x0 -> 0float_type, {}
main:pointwise209:@0 = main:pointwise71:@2@literal = {add = (, 0.00378788 -> }{1}, {(0@return0}}main:pointwise85:@3x00.05
:(1float_type,}x1 -> @param, float_type, {x1 = @param:{:x0x1 -> }, x1main:pointwise272:@01main:pointwise96:@3float_type, { -> x0 = ) -> { = float_type1, ,x1) -> float_type, {main:pointwise203:@11}, { -> float_type
@literal{, }, }, { = sigmoid -> 1}, {10float_type{0.5{00float_type, {, 1}, {x000{}
})main:pointwise8:@3})} ->  =  -> 
}{}
1main:pointwise57:@2float_typerelu1
{,  = }, (
relu
}, x0@param}, 
0{:{}{main:pointwise46:@20x0{x1float_type(

I'm not sure whether that contributes to the slice error, is caused by it, or is something else.
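One plausible explanation, not confirmed in this thread, is that several pointwise modules are printed from different threads at once (for example, if modules are compiled in parallel), so their output interleaves character by character. If so, the garbling would be a logging artifact rather than a sign of a bad module. A minimal sketch of the usual remedy: format each entry completely, then write it in one call under a mutex. `log_module` is a hypothetical helper, not MIGraphX's logger.

```cpp
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>

// Serializes whole log entries. Concurrent writes to std::cout do not corrupt the stream,
// but pieces written by different threads can interleave, as in the log above.
std::mutex log_mutex;

void log_module(std::ostream& os, const std::string& name, const std::vector<std::string>& lines)
{
    // Build the complete entry first, without holding the lock.
    std::ostringstream entry;
    entry << "module: \"" << name << "\"\n";
    for(const auto& line : lines)
        entry << line << '\n';
    // Emit it as a single write while holding the lock.
    std::lock_guard<std::mutex> guard(log_mutex);
    os << entry.str() << std::flush;
}
```

Usage: call `log_module(std::cout, "main:pointwise57", lines)` from each compile thread; each module's block then comes out contiguous.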

The log output from just doing a compile off develop is here: debug_resnext50_compile.log

The log output of the read is here: debug_resnext50_read.log

The output with the gather fixes from #1622 and #1587, showing the failure triggered by slice, is here: debug_resnext50_verify.log

TedThemistokleous commented 1 year ago

After comparing against onnxruntime's output (by modifying the accuracy_checker script to perform onnxruntime inferences), it appears onnxruntime also optimizes out the invalid slice during its optimizations.

An interesting side note with resnext50: we can't seem to run this network on onnxruntime alone, because one of the subgraphs within the nested If block (IF_1905) hits an error unless we rely on fallback to the ROCMExecutionProvider, which I added to the inference session's provider list to perform the run. This seems like something to look into later on the MIGraphX EP side.

To get this running on MIGraphX, it looks like I can get past slice by using the optimize flag in our driver and running per-instruction verify up to that point with the following command. I had to adjust my trim offset through some trial and error, given how the driver applies optimizations to the ONNX model.

bin/driver verify ../resnext50_32x4d_fpn.onnx -O -t 579 --gpu --per-instruction &> debug_resnext50_reduce.log

If I run into any issues, I'll start using lldb-12 to debug further and get a proper backtrace overnight. For now I'm just letting this run over the evening.

TedThemistokleous commented 1 year ago

We're finally getting runs on resnext50, thanks to @pfultz2 finding and fixing an issue with concat vectorization (#1653). In the meantime, I was working in parallel on making our NMS multithreaded for a single huge batch, thinking that was still the issue.
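Per-(batch, class) NMS problems are independent of each other, which is what makes multithreading NMS attractive. A minimal sketch, assuming one `std::async` task per class and a hypothetical `nms_one_class` callback (for example, a greedy NMS over that class's scores); this is an illustration of the idea, not the change made in this thread.

```cpp
#include <cstddef>
#include <functional>
#include <future>
#include <vector>

// Run NMS for each class concurrently. Results are collected in class order,
// so the output is deterministic regardless of thread scheduling.
std::vector<std::vector<std::size_t>>
parallel_nms(std::size_t num_classes,
             const std::function<std::vector<std::size_t>(std::size_t)>& nms_one_class)
{
    std::vector<std::future<std::vector<std::size_t>>> pending;
    pending.reserve(num_classes);
    for(std::size_t c = 0; c < num_classes; ++c)
        pending.push_back(std::async(std::launch::async, nms_one_class, c));

    std::vector<std::vector<std::size_t>> results;
    results.reserve(num_classes);
    for(auto& f : pending)
        results.push_back(f.get()); // waits for that class's task
    return results;
}
```

With many classes, a fixed-size thread pool would avoid one thread per class. With few classes and one huge batch, the available parallelism here is bounded by the class count.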

Applying both of these fixes on top of each other gives the following result:

Summary:
gpu::topk: 25017.3ms / 5 = 5003.47ms, 85%
nonmaxsuppression: 3860.93ms / 1 = 3860.93ms, 14%
hip::copy_from_gpu: 183.658ms / 4 = 45.9146ms, 1%
gpu::code_object::reduce_kernel: 143.436ms / 1 = 143.436ms, 1%
gpu::nonzero: 127.553ms / 5 = 25.5106ms, 1%
hip::copy_to_gpu: 70.1181ms / 1 = 70.1181ms, 1%
gpu::miopen_fusion: 23.6265ms / 49 = 0.482174ms, 1%
gpu::convolution: 20.3926ms / 53 = 0.384766ms, 1%
gpu::code_object::gather_kernel: 15.7473ms / 50 = 0.314946ms, 1%
gpu::code_object::concat_kernel: 9.74119ms / 19 = 0.512694ms, 1%
gpu::code_object::mul_add_kernel: 6.44609ms / 21 = 0.306956ms, 1%
gpu::code_object::convert_kernel: 4.20801ms / 12 = 0.350667ms, 1%
gpu::code_object::max_min_kernel: 3.53736ms / 10 = 0.353737ms, 1%
gpu::code_object::sub_kernel: 3.49076ms / 20 = 0.174538ms, 1%
gpu::code_object::min_exp_mul_mul_kernel: 3.35976ms / 10 = 0.335976ms, 1%
gpu::code_object::add_relu_kernel: 3.00536ms / 37 = 0.0812258ms, 1%
gpu::code_object::concat_add_kernel: 2.82924ms / 1 = 2.82924ms, 1%
gpu::code_object::add_kernel: 2.11693ms / 17 = 0.124525ms, 1%
gpu::code_object::contiguous_kernel: 1.44307ms / 10 = 0.144307ms, 1%
gpu::code_object::add_add_relu_kernel: 1.2778ms / 12 = 0.106483ms, 1%
gpu::code_object::concat_mod_kernel: 0.97079ms / 1 = 0.97079ms, 1%
gpu::code_object::less_convert_convert_logical_xor_mod_equal_convert_convert_not_logical_and_mul_add_where_kernel: 0.739805ms / 5 = 0.147961ms, 1%
gpu::code_object::gathernd_kernel: 0.723255ms / 5 = 0.144651ms, 1%
gpu::code_object::sigmoid_kernel: 0.655794ms / 5 = 0.131159ms, 1%
gpu::code_object::greater_convert_kernel: 0.479331ms / 5 = 0.0958662ms, 1%
gpu::code_object::mul_kernel: 0.47863ms / 1 = 0.47863ms, 1%
load: 0.365675ms / 407 = 0.000898464ms, 1%
gpu::pooling: 0.250205ms / 1 = 0.250205ms, 1%
hip::hip_copy_literal: 0.146673ms / 147 = 0.000997776ms, 1%
slice: 0.076922ms / 46 = 0.00167222ms, 1%
unsqueeze: 0.06204ms / 60 = 0.001034ms, 1%
broadcast: 0.060922ms / 53 = 0.00114947ms, 1%
multibroadcast: 0.055462ms / 56 = 0.000990393ms, 1%
step: 0.0426ms / 33 = 0.00129091ms, 1%
reshape: 0.035711ms / 36 = 0.000991972ms, 1%
gpu::code_object::relu_kernel: 0.0212ms / 1 = 0.0212ms, 1%
get_tuple_elem: 0.02079ms / 10 = 0.002079ms, 1%
squeeze: 0.019261ms / 12 = 0.00160508ms, 1%
transpose: 0.01602ms / 15 = 0.001068ms, 1%
@param: 0.006191ms / 4 = 0.00154775ms, 1%
flatten: 0.00549ms / 5 = 0.001098ms, 1%
hip::sync_stream: 0.00529ms / 1 = 0.00529ms, 1%
check_context::migraphx::version_2_6_0::gpu::context: 0.00369ms / 1 = 0.00369ms, 1%
hip::hip_allocate_memory: 0.00231ms / 1 = 0.00231ms, 1%

Batch size: 1
Rate: 0.0331548/sec
Total time: 30161.5ms
Total instructions time: 29509.5ms
Overhead time: 0.636814ms, 652.028ms
Overhead: 0%, 2%
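The two summaries can be cross-checked with a little arithmetic. The reported Rate agrees with 1000 / total time in ms to about five significant figures (one run at batch size 1). End-to-end time drops by about 3.36x and nonmaxsuppression by about 18x, while gpu::topk stays at roughly 25 s per run and becomes the dominant cost (85%). A small sketch of that arithmetic, using only figures copied from the two summaries; the meaning of Rate is inferred from the numbers, not from driver documentation.

```cpp
// Figures copied from the two perf summaries in this thread (milliseconds).
constexpr double total_before_ms = 101350.0;
constexpr double total_after_ms  = 30161.5;
constexpr double nms_before_ms   = 69473.1;
constexpr double nms_after_ms    = 3860.93;

// Runs per second for one run taking `total_ms` (appears to match the driver's "Rate").
constexpr double rate_per_sec(double total_ms) { return 1000.0 / total_ms; }

// How many times faster the "after" measurement is.
constexpr double speedup(double before_ms, double after_ms) { return before_ms / after_ms; }
```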

We still have an accuracy issue when using the accuracy checker, and we're currently seeing an HSA_FAULT when running with MIGRAPHX_GPU_DEBUG=1:

./migraphx/kernels/gather.hpp:59: operator(): error: Out of bounds access at offset: 23760000
:0:rocdevice.cpp            :2647: 4812524667207 us: 3344692: [tid:0x7f5677d9a700] Device::callbackQueue aborting with error : HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation

This is something I'm looking into between reviews.
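Worth noting when reading that fault: 23760000 is exactly 90000 * 264, the element count of the flattened sigmoid output in the earlier logs. So, if the offset is counted in elements and that is the buffer being indexed, the access is one past its last valid element. ONNX Gather accepts indices in [-n, n - 1] along the gathered axis. A minimal CPU sketch of a bounds-checked gather along axis 0 of a 1-D tensor; `gather_axis0` is a hypothetical helper, not the `gather.hpp` kernel.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Gather along axis 0 of a 1-D buffer using ONNX Gather index rules:
// valid indices are in [-n, n - 1], and negative indices count from the end.
// Returns std::nullopt instead of reading out of bounds.
std::optional<std::vector<float>> gather_axis0(const std::vector<float>& data,
                                               const std::vector<std::int64_t>& indices)
{
    const auto n = static_cast<std::int64_t>(data.size());
    std::vector<float> out;
    out.reserve(indices.size());
    for(std::int64_t idx : indices)
    {
        if(idx < 0)
            idx += n;
        if(idx < 0 or idx >= n)
            return std::nullopt; // e.g. idx == n, one past the end
        out.push_back(data[static_cast<std::size_t>(idx)]);
    }
    return out;
}
```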

TedThemistokleous commented 1 year ago

It looks like there are still a few errors and issues when trying to get accurate output, but for now we have baseline ref and GPU runs, so I'm closing this out and moving further debugging changes to #1698.