ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

migraphx does not support shape_dynamic, reshape_dynamic, constantofshape_dynamic, and so on #1975

Closed: qianqing13579 closed this issue 1 month ago

qianqing13579 commented 1 year ago

The following ops do not support dynamic shapes in MIGraphX.

I hope that MIGraphX can fully support dynamic shapes as soon as possible. Thanks!

qianqing13579 commented 1 year ago

Hello @CharlieL7 @umangyadav, we have many projects that need dynamic shapes, and I have recently been developing dynamic shape support for MIGraphX. My project is here. I want to share some ideas and the problems I ran into during the implementation. I hope we can solve dynamic shape support together.

The reasons why MIGraphX does not support dynamic shapes are as follows:

1. MIGraphX IR does not support dynamic models. The following ops in MIGraphX IR cannot support dynamic models.

The output shapes of these ops depend on the values of their input tensors. TVM Relax calls such ops data-dependent (https://github.com/tlc-pack/relax/wiki/Relax-Architecture-Overview). These ops should be extended to support dynamic shapes.

2. MIGraphX compilation needs shapes. Many compilation passes, such as memory coloring and op fusion, need shape information, but dynamic models cannot provide it at compile time.

3. GPU kernels cannot run with different shapes. MIGraphX JIT kernels, such as the pointwise JIT kernels, cannot run with different shapes. What's more, some ops, such as resize and lstm, are replaced with other ops, which can leave those ops without dynamic shape support.

4. MIGraphX cannot compute shapes at runtime. Dynamic models need to compute the output shape of each op from its input shapes at runtime.
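
Point 1 (and the runtime shape computation in point 4) can be illustrated with a small, dependency-free Python sketch. It is not MIGraphX code. The function names are hypothetical, and tensors are modelled as nested lists. The point is that two inputs with the same shape can produce different NonZero output shapes, so no compile-time pass can know the result. The shape has to be computed at runtime from the values.

```python
from typing import Iterator, List, Sequence


def _flatten(x) -> Iterator:
    """Yield the scalar elements of a nested-list tensor."""
    if isinstance(x, list):
        for item in x:
            yield from _flatten(item)
    else:
        yield x


def _rank(x) -> int:
    """Number of nesting levels, i.e. the tensor rank."""
    r = 0
    while isinstance(x, list):
        r += 1
        x = x[0] if x else None
    return r


def nonzero_output_shape(data) -> List[int]:
    """ONNX NonZero returns indices with shape [rank, number_of_nonzero_elements]."""
    count = sum(1 for v in _flatten(data) if v != 0)
    return [_rank(data), count]


def constantofshape_output_shape(shape_tensor: Sequence[int]) -> List[int]:
    """ONNX ConstantOfShape: the output shape is the *value* of its 1-D input."""
    if any(d < 0 for d in shape_tensor):
        raise ValueError("ConstantOfShape dimensions must be non-negative")
    return list(shape_tensor)


# Same input shape (2x2), different values -> different output shapes.
print(nonzero_output_shape([[1, 0], [0, 3]]))     # [2, 2]
print(nonzero_output_shape([[0, 0], [0, 5]]))     # [2, 1]
print(constantofshape_output_shape([1, 3, 224]))  # [1, 3, 224]
```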

My solutions to the dynamic shape problems are as follows:

  1. Extend MIGraphX IR.
  2. Set a max shape in parse_onnx. Shapes cannot exceed the max shape at runtime. With a max shape, many compile-time optimizations, such as memory coloring, can still be applied to dynamic models.
  3. Rewrite the JIT kernel implementations, such as the resize kernel.
  4. Use shape functions that compute shapes at runtime.
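
Solution 2 works because a compile-time max shape gives every intermediate buffer an upper-bound size. Memory planning can therefore be done once at compile time and reused for any smaller runtime shape. Below is a minimal Python sketch that assumes fp32 tensors and a simple greedy first-fit allocator over buffer lifetimes. It is an illustration, not MIGraphX's memory coloring pass, and the buffer names and lifetimes are hypothetical.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple

BYTES_PER_ELEMENT = 4  # assumption: fp32 tensors


def fits_max_shape(shape: Sequence[int], max_shape: Sequence[int]) -> bool:
    """Runtime check: same rank and no dimension larger than the compile-time max."""
    return len(shape) == len(max_shape) and all(d <= m for d, m in zip(shape, max_shape))


def plan_offsets(max_shapes: Dict[str, Sequence[int]],
                 lifetimes: Dict[str, Tuple[int, int]]) -> Dict[str, int]:
    """Greedy first-fit memory coloring with buffers sized by their max shapes.

    Buffers whose lifetimes [first_use, last_use] overlap must not share bytes.
    Buffers that are never alive at the same time may reuse the same offset.
    """
    placed: List[Tuple[int, int, int, int]] = []  # (offset, size, first_use, last_use)
    offsets: Dict[str, int] = {}
    # Place large buffers first; sorted() is stable, so ties keep dict order.
    for name in sorted(max_shapes, key=lambda n: -prod(max_shapes[n])):
        size = prod(max_shapes[name]) * BYTES_PER_ELEMENT
        start, end = lifetimes[name]
        # Byte ranges already used by buffers that are alive at the same time.
        busy = sorted((off, off + sz) for off, sz, s, e in placed
                      if not (end < s or e < start))
        offset = 0
        for lo, hi in busy:
            if offset + size <= lo:
                break  # the gap before this range is large enough
            offset = max(offset, hi)
        placed.append((offset, size, start, end))
        offsets[name] = offset
    return offsets


# Hypothetical buffers for a conv -> relu -> conv chain, sized from the max input {8, 3, 224, 224}.
max_shapes = {"conv1_out": [8, 64, 112, 112],
              "relu_out": [8, 64, 112, 112],
              "conv2_out": [8, 64, 112, 112]}
lifetimes = {"conv1_out": (0, 1), "relu_out": (1, 2), "conv2_out": (2, 3)}
print(plan_offsets(max_shapes, lifetimes))  # {'conv1_out': 0, 'relu_out': 25690112, 'conv2_out': 0}
print(fits_max_shape([1, 3, 112, 112], [8, 3, 224, 224]))   # True
print(fits_max_shape([16, 3, 224, 224], [8, 3, 224, 224]))  # False
```

Because the offsets are fixed at compile time from the max shape, a smaller runtime input only fills the beginning of each planned slot. This is also why a runtime shape must never exceed the max shape.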
My solutions can support many CV and NLP models with dynamic shapes:

Supported model    Dynamic mode
ResNet50           N,H,W
InceptionV3        N,H,W
MobileNetV2        N,H,W
DenseNet           N,H,W
MTCNN              N,H,W
SSD-VGG16          N,H,W
RetinaNet          N,H,W
RetinaFace         N,H,W
YOLOV3             N,H,W
YOLOV5             N,H,W
DBNet              N,H,W
FCN                N,H,W
UNet               N,H,W
CRNN-LSTM          H,W
SVTR               H,W
BERT               dynamic sequence
Transformer        dynamic sequence
GPT2               dynamic sequence

The problems with my solutions:

  1. The performance is poor. In my solutions, the shape-related ops are assigned to the GPU, which hurts performance. They should be assigned to the CPU instead. MIOpen convolution performance with dynamic shapes is also poor.
  2. Models are not supported if the inputs of data-dependent ops are model input parameters.
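
To see why the shape-related ops in problem 1 are better placed on the CPU, it helps to look at what a typical shape subgraph computes. Below is a host-only Python sketch of the Shape -> Gather -> Concat -> Reshape pattern that PyTorch commonly exports for `x.view(x.size(0), -1)`. The function names are hypothetical. Every step operates on a handful of integers, so running these ops as GPU kernels mainly adds overhead.

```python
from math import prod
from typing import List, Sequence


def shape_op(x_shape: Sequence[int]) -> List[int]:
    """ONNX Shape: the input's shape as a 1-D integer tensor."""
    return list(x_shape)


def gather_op(values: Sequence[int], index: int) -> int:
    """ONNX Gather on a 1-D tensor with a scalar index."""
    return values[index]


def reshape_output_shape(in_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """ONNX Reshape shape rules: 0 copies the input dim, a single -1 is inferred."""
    out = [in_shape[i] if d == 0 else d for i, d in enumerate(target)]
    if out.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    if -1 in out:
        known = prod(d for d in out if d != -1)
        out[out.index(-1)] = prod(in_shape) // known
    if prod(out) != prod(in_shape):
        raise ValueError(f"cannot reshape {list(in_shape)} to {list(target)}")
    return out


def flatten_subgraph(x_shape: Sequence[int]) -> List[int]:
    """Shape -> Gather(index 0) -> (Unsqueeze) -> Concat([N], [-1]) -> Reshape."""
    n = gather_op(shape_op(x_shape), 0)
    target = [n, -1]
    return reshape_output_shape(x_shape, target)


# The Reshape target changes with every runtime batch size.
print(flatten_subgraph([1, 2048, 1, 1]))  # [1, 2048]
print(flatten_subgraph([8, 2048, 1, 1]))  # [8, 2048]
```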

A dynamic shape demo:

#include <vector>
#include <migraphx/onnx.hpp>
#include <migraphx/program.hpp>
#include <migraphx/gpu/target.hpp>

// Set a max shape for the "input" parameter
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"] = {8, 3, 224, 224};

// Load model
migraphx::program net = migraphx::parse_onnx("ResNet50.onnx", onnx_options);

// Compile for the GPU
migraphx::compile_options options;
net.compile(migraphx::gpu::target{}, options);

// Set different shapes at runtime (none may exceed the max shape)
std::vector<std::vector<std::size_t>> inputShapes;
inputShapes.push_back({1, 3, 112, 112});
inputShapes.push_back({2, 3, 224, 224});

// Run inference with different shapes
for(std::size_t i = 0; i < inputShapes.size(); ++i)
{
    // Load and preprocess the input image into inputData["input"] with shape inputShapes[i]
    migraphx::parameter_map inputData;
    ...

    // Run inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    ...
}

CharlieL7 commented 1 year ago

Some of the recent PRs address several of your points, like the Shape ONNX operator. As a whole, full dynamic shape support still has significant development required.

qianqing13579 commented 1 year ago

@CharlieL7, thanks for sharing your ideas. I have a question: if dynamic batch is implemented by creating a static-shape submodule for each batch size, memory usage will increase, because a static-shape submodule has to be kept in memory for every shape. If the value range of a dynamic dimension is large, memory may overflow.

CharlieL7 commented 1 year ago

@qianqing13579 Yes, storage of compile-time constants is a valid concern with how dynamic batch is currently implemented in MIGX. There's a compiler pass promote_literals that will move literals to the main module. That allows repeated literals to be optimized out. I'm assuming most users will only be interested in a couple of batch sizes in practice. We will add functionality to make it possible to set the specific batch sizes to support.

Because of the way dynamic batch is implemented, it should retain the performance of static-batch models. To fully support dynamic shapes in models with data-dependent shape functions (like single-shot detector models), we would definitely lose performance.
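
The memory trade-off in the last few comments can be made concrete with back-of-the-envelope arithmetic. This is a hypothetical Python model, not MIGraphX internals. It assumes that each compiled batch-size variant keeps its own activation memory, and that literals (weights) are either duplicated per variant or shared once, which is the effect CharlieL7 attributes to promote_literals. All sizes are made-up units.

```python
from dataclasses import dataclass
from typing import Dict, Iterable


@dataclass(frozen=True)
class StaticVariant:
    batch: int
    activation_units: int  # assumption: scratch memory planned for this fixed batch size


def build_variants(batch_sizes: Iterable[int], act_units_per_sample: int) -> Dict[int, StaticVariant]:
    """One statically shaped variant per supported batch size (hypothetical model)."""
    return {b: StaticVariant(b, b * act_units_per_sample) for b in batch_sizes}


def total_memory(variants: Dict[int, StaticVariant], weight_units: int, share_literals: bool) -> int:
    """Activations of every variant plus weights stored once, or once per variant."""
    activations = sum(v.activation_units for v in variants.values())
    weights = weight_units if share_literals else weight_units * len(variants)
    return activations + weights


def select_variant(variants: Dict[int, StaticVariant], batch: int) -> StaticVariant:
    """Runtime dispatch: only batch sizes compiled ahead of time can run."""
    if batch not in variants:
        raise KeyError(f"batch size {batch} not compiled; available: {sorted(variants)}")
    return variants[batch]


few = build_variants([1, 2, 4, 8], act_units_per_sample=10)
print(total_memory(few, weight_units=100, share_literals=False))  # 550
print(total_memory(few, weight_units=100, share_literals=True))   # 250
print(select_variant(few, 4).batch)                               # 4

# Supporting every batch size from 1 to 64 makes memory grow with the number of variants.
wide = build_variants(range(1, 65), act_units_per_sample=10)
print(total_memory(wide, weight_units=100, share_literals=True))  # 20900
```

Under the same made-up units, a single plan sized for the max batch of 64 would need about 64 × 10 + 100 = 740 units. That gap is the concern qianqing13579 raises, and it is why limiting the set of compiled batch sizes matters.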

CharlieL7 commented 1 year ago

@qianqing13579 I've looked through your fork. Your dynamic shape implementation is similar, but it will not work directly with how we plan to extend MIGraphX for full dynamic shape support. I think we might be able to take in pieces of your work for dynamic shape support. Let me know if you want to discuss the design or plan for full dynamic shape support.

qianqing13579 commented 1 year ago

@CharlieL7, I want to discuss the design and plan for full dynamic shape support. Recently I have been refactoring the project and cleaning up the code.

lzd-1230 commented 11 months ago

Hi guys, I've run into a problem using MIGraphX to run inference on an ONNX model with dynamic inputs. First, I export the model:

import torch

dynamic_ax = {'input': {2: 'image_height', 3: 'image_width'},
              'output': {2: 'image_height', 3: 'image_width'}}
torch.onnx.export(model, (x), onnx_model_path,
                  input_names=["input"], output_names=["output"], verbose=False,
                  opset_version=11, dynamic_axes=dynamic_ax)

Then I save it to .mxr with the MIGraphX Python API:

model = migraphx.parse_onnx(model_path, map_dyn_input_dims = max_input) 
model.compile(t = migraphx.get_target("gpu"), device_id=0)
migraphx.save(model, mxr_model_path)

Finally, I run inference with the model:

model = migraphx.load(mxr_model_path)
inputName = model.get_parameter_names()[0]
pred = model.run({inputName: migraphx.argument(img.cpu().numpy())})     

But I get the following error. Is this the shape_dynamic problem you are discussing, or is it a different problem? If so, could you give me some guidance on it?

Invalid address access: 0x145d38390000, Error code: 1.
>>>>>>>> KERNEL VMFault !!!! <<<<<<
>>>>>>>> PID: 421225 !!!! <<<<<<
=========> STREAM <0x55ebde83ffc0>: VMFault HSA QUEUE ANALYSIS <=========
=========> STREAM <0x55ebde840090>: VMFault HSA QUEUE ANALYSIS <=========
=========> STREAM <0x55ebe42fd9b0>: VMFault HSA QUEUE ANALYSIS <=========
STREAM <0x55ebe42fd9b0>: >>>>>>>> DUMP KERNEL AQL PACKET <<<<<<<<<
STREAM <0x55ebe42fd9b0>: header: 770
STREAM <0x55ebe42fd9b0>: setup: 3
STREAM <0x55ebe42fd9b0>: workgroup: x:256, y:1, z:1
STREAM <0x55ebe42fd9b0>: grid: x:786432, y:1, z:1
STREAM <0x55ebe42fd9b0>: group_segment_size: 0
STREAM <0x55ebe42fd9b0>: private_segment_size: 0
STREAM <0x55ebe42fd9b0>: kernel_object: 22401955784704

SUCCESS: FIND SAME KERNEL OBJECT COMMAND IN USE LIST. useIdx: 0
STREAM <0x55ebe42fd9b0>: >>>>>>>> FIND MATCH KERNEL COMMAND <<<<<<<<<
STREAM <0x55ebe42fd9b0>: kernel name: _ZN8migraphx9version_13gpu6device8launcherIZZNS2_9mi_launchILj4EEEDaP12ihipStream_tRKNS2_9hip_shapeIXT_EEEjENKUlT_E_clIZZNS2_12mi_gs_launchILj4EEEDaS6_SA_jENKUlSB_E_clIZZZNS2_22contiguous_nonstandardES6_RKNS0_8argumentESJ_ENKUlSB_T0_E_clINS0_11tensor_viewIfEESO_EEDaSB_SK_ENKUlSB_SK_T1_E_clINS2_15hip_tensor_viewIfLj4EEEST_NS7_ILj4EEEEEDaSB_SK_SP_EUlSB_E_EEDaSB_EUlSB_SK_E_EEDaSB_EUlSB_E_EEvSB_
STREAM <0x55ebe42fd9b0>: >>>>>>>> DUMP KERNEL ARGS: size: 248 <<<<<<<<<

01 00 00 00 03 00 00 00 00 02 00 00 00 02 00 00 
00 00 0c 00 00 00 04 00 00 02 00 00 01 00 00 00 
00 00 00 00 00 04 00 00 56 55 55 55 55 01 00 00 
00 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00 
01 fa ac e3 eb 55 00 00 01 00 00 00 00 00 00 00 
00 00 00 00 00 00 00 00 00 00 c0 d4 5c 14 00 00 
01 00 00 00 03 00 00 00 00 02 00 00 00 02 00 00 
00 00 0c 00 00 00 04 00 00 02 00 00 01 00 00 00 
00 00 00 00 00 04 00 00 56 55 55 55 55 01 00 00 
00 00 00 00 02 00 00 00 00 00 00 00 02 00 00 00 
01 ff d7 e8 5f 14 00 00 c0 09 20 38 5d 14 00 00 
01 00 00 00 03 00 00 00 40 01 00 00 aa 01 00 00 
80 3d 06 00 80 14 02 00 aa 01 00 00 01 00 00 00 
00 00 00 00 00 04 00 00 56 55 55 55 55 01 00 00 
34 33 33 33 03 00 00 00 6b 8b 5c 67 02 00 00 00 
01 bc 9d ea 5f 14 00 00 

STREAM <0x55ebe42fd9b0>: >>>>>>>> DUMP KERNEL ARGS PTR INFO <<<<<<<<<
STREAM <0x55ebe42fd9b0>: ptr arg index: 11, ptr: 0x145cd4c00000
STREAM <0x55ebe42fd9b0>: host ptr: 0x145cd0c00000, device ptr: 0x145cd0c00000, unaligned ptr: 0x145cd0c00000
STREAM <0x55ebe42fd9b0>: size byte: 342360065
STREAM <0x55ebe42fd9b0>: ptr arg index: 21, ptr: 0x145d382009c0
STREAM <0x55ebe42fd9b0>: host ptr: 0x55ebf57489c0, device ptr: 0x145d382009c0, unaligned ptr: 0x145d382009c0
STREAM <0x55ebe42fd9b0>: size byte: 1635840

.......

I would be very grateful for a reply.

qianqing13579 commented 11 months ago

@lzd-1230, you should use the ONNX model, and the program should set the max input shape. Sample:

import migraphx

maxInput = {"input": [8, 3, 224, 224]}

model = migraphx.parse_onnx("ResNet50.onnx", map_input_dims=maxInput)

model.compile(t=migraphx.get_target("gpu"), device_id=0)

# ... load data and infer

lzd-1230 commented 11 months ago

Thanks for the reply. I used exactly the ONNX model and map_input_dims=maxInput that you described (I mistyped the parameter name when asking the question), and the error is the same. But the same code runs inference without problems in migraphx==4.0.0!
Is there any clue about the differences between migraphx==2.5 and 4.0 that could explain this? And can I find a workaround in a 2.5 environment? I can't upgrade MIGraphX on my dev server, because the special GPU provider does not support version 4.0 yet.

CharlieL7 commented 11 months ago

@lzd-1230 It does not look like you are using dynamic batch or dynamic shapes for the ResNet50 model. Open up a new issue with the details of what you want to do. It's unclear to me what exactly you mean by the MIGraphX version and gpu provider.