ROCm / AMDMIGraphX

AMD's graph optimization engine.
https://rocm.docs.amd.com/projects/AMDMIGraphX/en/latest/
MIT License

MLIR conv fails verify when padding is not equal in all dims #2407

Open shivadbhavsar opened 11 months ago

shivadbhavsar commented 11 months ago

Here is a minimal example to reproduce this issue:

  1. Generate an .mxr file by running the following Python script:

    import migraphx
    import numpy as np

    if __name__ == "__main__":
        p = migraphx.program()
        mm = p.get_main_module()
        inp = mm.add_parameter("x", migraphx.shape(lens=[8, 3, 50, 50]))
        weights_np = np.random.randn(16, 3, 3, 3).astype(np.float32)
        w = mm.add_literal(weights_np)
        out_mgx = mm.add_instruction(migraphx.op('convolution', padding=[1, 2]), [inp, w])
        mm.add_return([out_mgx])
        migraphx.save(p, "conv_fail.mxr")

This writes "conv_fail.mxr", which contains the following uncompiled program:

    @0 = @literal{ ... } -> float_type, {16, 3, 3, 3}, {27, 9, 3, 1}, target_id=0
    x = @param:x -> float_type, {8, 3, 50, 50}, {7500, 2500, 50, 1}, target_id=0
    @2 = convolution[padding={1, 2},stride={1, 1},dilation={1, 1},group=1,padding_mode=0] -> float_type, {8, 16, 50, 52}, {41600, 2600, 52, 1}, target_id=0
    @3 = @return(@2), target_id=0
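The output shape {8, 16, 50, 52} shows that the two-element padding [1, 2] is applied per spatial dimension (1 on each side of H, 2 on each side of W), which is exactly the "unequal in all dims" case. A minimal sketch of that shape arithmetic, assuming the standard convolution output-length formula (this is an illustration, not MIGraphX code; the helper names are hypothetical):

```python
def conv_output_len(in_len, kernel, pad, stride=1, dilation=1):
    """Output length along one spatial dim, with `pad` added on both sides."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_len + 2 * pad - effective_kernel) // stride + 1


def conv_output_shape(input_lens, weight_lens, padding, stride=(1, 1), dilation=(1, 1)):
    """NCHW output shape for a convolution with symmetric per-dimension padding."""
    n, _, *spatial = input_lens
    out_channels, _, *kernel = weight_lens
    out_spatial = [
        conv_output_len(i, k, p, s, d)
        for i, k, p, s, d in zip(spatial, kernel, padding, stride, dilation)
    ]
    return [n, out_channels, *out_spatial]


print(conv_output_shape([8, 3, 50, 50], [16, 3, 3, 3], padding=[1, 2]))  # [8, 16, 50, 52]
```

With padding [1, 1] or [2, 2] (the cases reported to pass) the two spatial dims stay equal, so only the asymmetric case produces a non-square output.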


2. Run `migraphx-driver verify` on the file (the `--int8` flag is not needed to reproduce):

**With MLIR**:

    migraphx-driver verify conv_fail.mxr

Output:

    FAILED: conv_fail.mxr RMS Error: 0.321727 Max diff: 22.0078 Mismatch at 0: 0.992833 != 1.84061


**Without MLIR**:

    MIGRAPHX_DISABLE_MLIR=1 migraphx-driver verify conv_fail.mxr

Output:

    MIGraphX verification passed successfully.
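The FAILED line summarizes how far the two sets of results differ (an RMS error, the largest element-wise difference, and the first mismatching index with both values). A rough sketch of computing comparable statistics for two output arrays, assuming a plain root-mean-square of the difference; the driver's own RMS may be normalized differently, and the `rtol`/`atol` thresholds and the `compare_outputs` name are hypothetical choices for illustration:

```python
import numpy as np


def compare_outputs(actual, expected, rtol=1e-3, atol=1e-3):
    """Return (rms_error, max_abs_diff, first_mismatch_index) for two arrays.

    rms_error is the plain root-mean-square of the element-wise difference.
    A mismatch is any element failing np.isclose with the given tolerances.
    """
    a = np.asarray(actual, dtype=np.float64).ravel()
    e = np.asarray(expected, dtype=np.float64).ravel()
    diff = a - e
    rms = float(np.sqrt(np.mean(diff ** 2)))
    max_diff = float(np.max(np.abs(diff)))
    bad = np.flatnonzero(~np.isclose(a, e, rtol=rtol, atol=atol))
    first = int(bad[0]) if bad.size else None
    return rms, max_diff, first


print(compare_outputs([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # (0.0, 0.0, None)
print(compare_outputs([0.5, 2.0], [1.0, 2.0]))
```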



Note: I am using `driver verify` here because it makes the issue easy to reproduce, but I first hit this while running a test case in torch_migraphx. I initially saw it with quantized kernels, but the steps above reproduce the bug without any quantization. As far as I can tell, the unequal padding alone causes the issue: padding attributes of (1, 1) and (2, 2) both pass `verify` and the torch_migraphx tests.
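To check a backend's result for the failing case independently of MIGraphX, a direct NumPy convolution with per-dimension padding can serve as a reference. This is a minimal sketch, assuming NCHW layout, stride 1, dilation 1, group 1, and symmetric padding `(pad_h, pad_w)` as implied by the printed output shape; `conv2d_reference` is a hypothetical helper, not part of MIGraphX:

```python
import numpy as np


def conv2d_reference(x, w, padding=(0, 0)):
    """Direct NCHW convolution (cross-correlation), stride 1, dilation 1, group 1.

    padding=(pad_h, pad_w) is zero-padding applied on both sides of H and W.
    """
    pad_h, pad_w = padding
    x_padded = np.pad(x, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)))
    n, c, h, wd = x_padded.shape
    k, c_w, kh, kw = w.shape
    if c != c_w:
        raise ValueError("input and weight channel counts differ")
    out_h, out_w = h - kh + 1, wd - kw + 1
    out = np.zeros((n, k, out_h, out_w), dtype=np.result_type(x, w))
    # Accumulate one kernel tap (i, j) at a time over all channels.
    for i in range(kh):
        for j in range(kw):
            patch = x_padded[:, :, i:i + out_h, j:j + out_w]  # (n, c, out_h, out_w)
            out += np.einsum("nchw,kc->nkhw", patch, w[:, :, i, j])
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3, 50, 50)).astype(np.float32)
w = rng.standard_normal((16, 3, 3, 3)).astype(np.float32)
ref = conv2d_reference(x, w, padding=(1, 2))
print(ref.shape)  # (8, 16, 50, 52)
```

Feeding the same `x` and weights to the compiled program with and without MLIR and comparing each against `ref` would show which path diverges.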

EDIT: This seems to happen on Navi systems but not on MI systems.
pfultz2 commented 11 months ago

BTW, you can pass a Python file directly to the driver instead of saving an intermediate file: `./bin/driver read conv_fail.py`

    # conv_fail.py
    import migraphx
    import numpy as np

    p = migraphx.program()
    mm = p.get_main_module()
    inp = mm.add_parameter("x", migraphx.shape(lens=[8, 3, 50, 50]))
    weights_np = np.random.randn(16, 3, 3, 3).astype(np.float32)
    w = mm.add_literal(weights_np)
    out_mgx = mm.add_instruction(migraphx.op('convolution', padding=[1, 2]),
                                 [inp, w])
    mm.add_return([out_mgx])