iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Compiling mobilenetv3 generates out of bounds access on one of the convolutions (in backward pass). #4706

Closed: phoenix-meadowlark closed this issue 2 years ago

phoenix-meadowlark commented 3 years ago

Attempting to compile part of MobileNetV3's update step produced the following fatal errors across several runs:

free(): invalid next size (normal)
[1]    1559527 abort      python mobilenetv3_no_batchnorm.py
double free or corruption (!prev)
[1]    1560126 abort      python mobilenetv3_no_batchnorm.py
corrupted double-linked list
[1]    1560870 abort      python mobilenetv3_no_batchnorm.py
python: malloc.c:3839: _int_malloc: Assertion `chunk_main_arena (bck->bk)' failed.
(no error, but it hung for several minutes, and now won't respond to any number of keyboard interrupts.)

This is the raw Python code that produces the above errors. Sometimes it crashes before the C++ call has finished; sometimes it only crashes after the Python program errors out or attempts to exit successfully.

I also created mobilenetv3_reproducer.mlir via:

iree.jax.aot(update, optimizer, batch, import_only=True,
             output_file="/tmp/mobilenetv3_reproducer.mlir")

From what I can tell, this should be reproducible with the following code:

import pyiree as iree
import pyiree.compiler2
import pyiree.rt

import numpy as np
from numpy import dtype

binary = iree.compiler2.compile_file("/tmp/mobilenetv3_reproducer.mlir",
                                     target_backends=["dylib-llvm-aot"])
cpp_vm_module = iree.rt.VmModule.from_flatbuffer(binary)
module = iree.rt.load_module(cpp_vm_module, config=iree.rt.Config("dylib"))

signature = [
    ((), dtype('int32')),
    ((1, 1, 16, 16), dtype('float32')),
    ((1, 1, 16, 16), dtype('float32')),
    ((3, 3, 1, 16), dtype('float32')),
    ((8,), dtype('float32')),
    ((1, 1, 16, 8), dtype('float32')),
    ((16,), dtype('float32')),
    ((1, 1, 8, 16), dtype('float32')),
    ((1, 16, 16, 16), dtype('float32')),
    ((1,), dtype('int32')),
]
def ndarange(shape, dtype):
  # Deterministic inputs 0, 1, 2, ... reshaped to `shape`.
  return np.arange(np.prod(shape), dtype=dtype).reshape(shape)

inputs = [ndarange(shape, dt) for shape, dt in signature]

baseline = module.main(*inputs)
for i in range(1000):
  outputs = module.main(*inputs)
  for a, b in zip(baseline, outputs):
    np.testing.assert_equal(a, b)  # just doing something similar to what the original script does.

but this runs without any issue, so I'm at a bit of a loss. This is just a flattened version of what the @iree.jax.jit decorator does.

stellaraccident commented 3 years ago

Actually, the current failure doesn't appear to be related to the Python bindings (or, if it is, only oddly so):

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007ffff7c4f537 in __GI_abort () at abort.c:79
#2  0x00007ffff7ca8768 in __libc_message (action=action@entry=do_abort, fmt=fmt@entry=0x7ffff7db6e31 "%s\n") at ../sysdeps/posix/libc_fatal.c:155
#3  0x00007ffff7cafa5a in malloc_printerr (str=str@entry=0x7ffff7db5041 "corrupted double-linked list") at malloc.c:5347
#4  0x00007ffff7cb078c in unlink_chunk (p=p@entry=0x403cc70, av=0x7ffff7de8b80 <main_arena>) at malloc.c:1460
#5  0x00007ffff7cb08f7 in malloc_consolidate (av=av@entry=0x7ffff7de8b80 <main_arena>) at malloc.c:4502
#6  0x00007ffff7cb10c0 in _int_free (av=0x7ffff7de8b80 <main_arena>, p=0x44eecc0, have_lock=<optimized out>) at malloc.c:4400
#7  0x00007ffff71210d4 in iree_allocator_system_free (self=0x0, ptr=0x44fed10) at /usr/local/google/home/laurenzo/src/iree/iree/base/api.c:1111
#8  0x00007ffff711fe1f in iree_allocator_free (allocator=..., ptr=0x44fed10) at /usr/local/google/home/laurenzo/src/iree/iree/base/api.c:1061
#9  0x00007ffff7313109 in iree_hal_heap_buffer_destroy (base_buffer=0x4489510) at /usr/local/google/home/laurenzo/src/iree/iree/hal/buffer_heap.c:69
#10 0x00007ffff712542b in iree_hal_buffer_destroy (buffer=0x4489510) at /usr/local/google/home/laurenzo/src/iree/iree/hal/buffer.c:116
#11 0x00007ffff731bf13 in iree_vm_ref_release (ref=0x3f60710) at /usr/local/google/home/laurenzo/src/iree/iree/vm/ref.c:229
#12 0x00007ffff7130a64 in iree::hal::(anonymous namespace)::HALModuleState::ExSubmitAndWait (this=0x408fd20, device=..., command_buffer=...)
    at /usr/local/google/home/laurenzo/src/iree/iree/modules/hal/hal_module.cc:199
#13 0x00007ffff71360bf in iree::vm::packing::DispatchFunctorVoid<iree::hal::(anonymous namespace)::HALModuleState, iree::vm::ref<iree_hal_device_s> const&, iree::vm::ref<iree_hal_command_buffer_s> const&>::ApplyFn<std::tuple<iree::vm::ref<iree_hal_device_s>, iree::vm::ref<iree_hal_command_buffer_s> >, 0ul, 1ul> (ptr=
    (iree::Status (iree::hal::(anonymous namespace)::HALModuleState::*)(iree::hal::(anonymous namespace)::HALModuleState * const, const iree::vm::ref<iree_hal_device_s> &, const iree::vm::ref<iree_hal_command_buffer_s> &)) 0x7ffff7130790 <iree::hal::(anonymous namespace)::HALModuleState::ExSubmitAndWait(iree::vm::ref<iree_hal_device_s> const&, iree::vm::ref<iree_hal_command_buffer_s> const&)>, self=0x408fd20, 
    params=...) at /usr/local/google/home/laurenzo/src/iree/iree/vm/module_abi_packing.h:609
#14 0x00007ffff7135f9e in iree::vm::packing::DispatchFunctorVoid<iree::hal::(anonymous namespace)::HALModuleState, iree::vm::ref<iree_hal_device_s> const&, iree::vm::ref<iree_hal_command_buffer_s> const&>::Call
    (
    ptr=(void (iree::hal::(anonymous namespace)::HALModuleState::*)(iree::hal::(anonymous namespace)::HALModuleState * const)) 0x7ffff7130790 <iree::hal::(anonymous namespace)::HALModuleState::ExSubmitAndWait(iree::vm::ref<iree_hal_device_s> const&, iree::vm::ref<iree_hal_command_buffer_s> const&)>, self=0x408fd20, stack=0x7fffffff5dd0, call=0x7fffffff4df0, out_result=0x7fffffff5c80)
    at /usr/local/google/home/laurenzo/src/iree/iree/vm/module_abi_packing.h:602
#15 0x00007ffff7140480 in iree::vm::NativeModule<iree::hal::(anonymous namespace)::HALModuleState>::ModuleBeginCall (self=0x13c8a70, stack=0x7fffffff5dd0, call=0x7fffffff4df0, out_result=0x7fffffff5c80)
    at /usr/local/google/home/laurenzo/src/iree/iree/vm/native_module_cc.h:238
#16 0x00007ffff719f144 in iree_vm_bytecode_issue_import_call (stack=0x7fffffff5dd0, call=..., cconv_results=..., dst_reg_list=0x7ffdd400d7c3, out_caller_frame=0x7fffffff59e8, 
    out_caller_registers=0x7fffffff59c8, out_result=0x7fffffff5c80) at /usr/local/google/home/laurenzo/src/iree/iree/vm/bytecode_dispatch.c:465
#17 0x00007ffff719dd50 in iree_vm_bytecode_call_import (stack=0x7fffffff5dd0, module_state=0x40a2c50, import_ordinal=1, caller_registers=..., src_reg_list=0x7ffdd400d7bd, dst_reg_list=0x7ffdd400d7c3, 
    out_caller_frame=0x7fffffff59e8, out_caller_registers=0x7fffffff59c8, out_result=0x7fffffff5c80) at /usr/local/google/home/laurenzo/src/iree/iree/vm/bytecode_dispatch.c:546
#18 0x00007ffff719a749 in iree_vm_bytecode_dispatch (stack=0x7fffffff5dd0, module=0x36977d0, call=0x7fffffff5c88, cconv_arguments=..., cconv_results=..., out_result=0x7fffffff5c80)
    at /usr/local/google/home/laurenzo/src/iree/iree/vm/bytecode_dispatch.c:1120
#19 0x00007ffff7192fd4 in iree_vm_bytecode_module_begin_call (self=0x36977d0, stack=0x7fffffff5dd0, call=0x7fffffff5c88, out_result=0x7fffffff5c80)
    at /usr/local/google/home/laurenzo/src/iree/iree/vm/bytecode_module.c:710
#20 0x00007ffff73189ef in iree_vm_invoke_within (context=0x4496fc0, stack=0x7fffffff5dd0, function=..., policy=0x0, inputs=0x449ba60, outputs=0x408fcc0)
    at /usr/local/google/home/laurenzo/src/iree/iree/vm/invocation.c:154
#21 0x00007ffff731863f in iree_vm_invoke (context=0x4496fc0, function=..., policy=0x0, inputs=0x449ba60, outputs=0x408fcc0, allocator=...) at /usr/local/google/home/laurenzo/src/iree/iree/vm/invocation.c:177
#22 0x00007ffff71037ce in iree::python::VmContext::Invoke (this=0x3f66080, f=..., inputs=..., outputs=...) at /usr/local/google/home/laurenzo/src/iree/bindings/python/pyiree/rt/vm.cc:142
#23 0x00007ffff711b7f3 in pybind11::cpp_function::cpp_function<void, iree::python::VmContext, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (iree::python::VmContext::*)(iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&)#1}::operator()(iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&) const (this=0xa914f8, c=0x3f66080, args=..., args=..., args=...) at /usr/local/google/home/laurenzo/src/iree/third_party/pybind11/include/pybind11/pybind11.h:78
#24 0x00007ffff711b710 in pybind11::detail::argument_loader<iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&>::call_impl<void, pybind11::cpp_function::cpp_function<void, iree::python::VmContext, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (iree::python::VmContext::*)(iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&)#1}&, 0ul, 1ul, 2ul, 3ul, pybind11::detail::void_type>(pybind11::cpp_function::cpp_function<void, iree::python::VmContext, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (iree::python::VmContext::*)(iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&)#1}&, std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul>, pybind11::detail::void_type&&) && (this=0x7fffffff8018, f=...)
    at /usr/local/google/home/laurenzo/src/iree/third_party/pybind11/include/pybind11/cast.h:2002
#25 0x00007ffff711aff6 in pybind11::detail::argument_loader<iree::python::VmContext*, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&>::call<void, pybind11::detail::void_type, pybind11::cpp_function::cpp_function<void, iree::python::VmContext, iree_vm_function_t, iree::python::VmVariantList&, iree::python::VmVariantList&, pybind11::name, pybind11::is_method, pybind11::sibling>(void (iree:
stellaraccident commented 3 years ago

Gets an abort with corrupted double-linked list

I can triage further.

stellaraccident commented 3 years ago

Ok, I nailed this down to bad generated code for one of the convolutions.

==2953662==ERROR: AddressSanitizer: heap-buffer-overflow on address 0x6160001c22c0 at pc 0x7f3538bb56fc bp 0x7f349ffd4380 sp 0x7f349ffd4378
READ of size 4 at 0x6160001c22c0 thread T64 (worker[0])
    #0 0x7f3538bb56fb in main_ex_dispatch_32 /usr/local/google/home/laurenzo/src/scratch/./mobilenet_v3_reproducer.mlir:181

0x6160001c22c0 is located 0 bytes to the right of 576-byte region [0x6160001c2080,0x6160001c22c0)
allocated by thread T0 here:
SUMMARY: AddressSanitizer: heap-buffer-overflow /usr/local/google/home/laurenzo/src/scratch/./mobilenet_v3_reproducer.mlir:181 in main_ex_dispatch_32
Shadow bytes around the buggy address:

The referenced line 181 is:

   %165 = "mhlo.convolution"(%20, %143) {batch_group_count = 16 : i64, dimension_numbers = {input_batch_dimension = 3 : i64, input_feature_dimension = 0 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 0 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, output_batch_dimension = 2 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>}, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<1x16x16x16xf32>, tensor<1x16x16x16xf32>) -> tensor<3x3x1x16xf32>
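A quick arithmetic check on the ASan report: the faulting address is exactly the end of the 576-byte region, and 576 bytes is 144 float32 values, the same element count as the tensor<3x3x1x16xf32> result of the referenced line (and the (3, 3, 1, 16) float32 entry in the reproducer's signature). The sketch below just redoes that arithmetic with NumPy; `tensor_bytes` is a hypothetical helper, and the match is consistent with a one-element-past-the-end read of a 3x3x1x16 f32 buffer, not proof of which buffer was overrun.

```python
import numpy as np

# Values copied from the ASan report above.
region_start = 0x6160001C2080  # start of the 576-byte heap region
region_end = 0x6160001C22C0    # one past its last byte
bad_addr = 0x6160001C22C0      # address of the 4-byte READ in main_ex_dispatch_32


def tensor_bytes(shape, dtype=np.float32):
    """Hypothetical helper: size in bytes of a dense tensor."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


f32 = np.dtype(np.float32).itemsize
region_size = region_end - region_start
bad_index = (bad_addr - region_start) // f32

print(region_size)                  # 576
print(tensor_bytes((3, 3, 1, 16)))  # 576
print(bad_index)                    # 144: valid float32 indices are 0..143
```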

My repro scripts, reproducer and notes are snapshotted in this repo: https://github.com/stellaraccident/repro_iree_debugging

stellaraccident commented 3 years ago

Also major thanks to @inho9606 for authoring #4676. I rebased on their patch and got it working, and that was basically the only way to track this down. Major thanks!

hanhanW commented 3 years ago

Good catch! I will take a look.

hanhanW commented 3 years ago

Narrowed the IR down to:

func @foo(%arg0: tensor<1x16x16x16xf32>, %arg1: tensor<1x16x16x16xf32>) -> tensor<3x3x1x16xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 16 : i64,
    dimension_numbers = {
      input_batch_dimension = 3 : i64,
      input_feature_dimension = 0 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 0 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      output_batch_dimension = 2 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<2xi64>,
    padding = dense<1> : tensor<2x2xi64>,
    precision_config = ["DEFAULT", "DEFAULT"],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x16x16x16xf32>, tensor<1x16x16x16xf32>) -> tensor<3x3x1x16xf32>
  return %0 : tensor<3x3x1x16xf32>
}
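As a reference for what this op should compute, here is a minimal NumPy sketch based on my reading of XLA's batch_group_count semantics: the lhs batch dimension and the output feature dimension are each split into batch_group_count groups, and group k of the output is computed only from group k of the lhs. `conv_batch_grouped_ref` is a hypothetical helper; it assumes stride 1, no dilation, feature_group_count = 1, symmetric zero padding, and exactly one lhs batch element and one output feature per group, as in this op. Read this way, the op is the filter gradient of a 3x3 depthwise convolution over 16 channels.

```python
import numpy as np


def conv_batch_grouped_ref(lhs, rhs, batch_group_count, pad=1):
    """Hypothetical NumPy reference for the narrowed mhlo.convolution.

    lhs: [feature, h, w, batch]             (input_feature=0, spatial=1,2, batch=3)
    rhs: [in_feature, kh, kw, out_feature]  (kernel spatial=1,2, out_feature=3)
    returns [oh, ow, batch // G, out_feature] (output spatial=0,1, batch=2, feature=3)
    """
    f, h, w, b = lhs.shape
    i, kh, kw, o = rhs.shape
    g = batch_group_count
    assert f == i, "feature_group_count=1 needs matching feature dims"
    assert b == g and o == g, "sketch only handles batch == out_feature == G"
    # Zero-pad both spatial dims by `pad` on each side (padding = dense<1>).
    lhs_p = np.pad(lhs, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    oh, ow = h + 2 * pad - kh + 1, w + 2 * pad - kw + 1
    out = np.zeros((oh, ow, b // g, o), dtype=lhs.dtype)
    for y in range(oh):
        for x in range(ow):
            window = lhs_p[:, y:y + kh, x:x + kw, :]  # [f, kh, kw, batch]
            # Group k pairs lhs batch element k with kernel output feature k.
            out[y, x, 0, :] = np.einsum("fhwk,fhwk->k", window, rhs)
    return out


lhs = np.ones((1, 16, 16, 16), np.float32)
rhs = np.ones((1, 16, 16, 16), np.float32)
out = conv_batch_grouped_ref(lhs, rhs, batch_group_count=16)
print(out.shape)  # (3, 3, 1, 16)
```

With all-ones inputs, each output element counts how many of the 16x16 window positions land on real (unpadded) data: 225 at the corners, 240 on the edges and 256 in the center.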

This is a depthwise convolution op. After HLOPreprocessingPass, it becomes:

func @foo(%arg0: tensor<1x16x16x16xf32>, %arg1: tensor<1x16x16x16xf32>) -> tensor<3x3x1x16xf32> attributes {iree.module.export} {
  %cst = constant dense<0.000000e+00> : tensor<f32>
  %0 = "mhlo.pad"(%arg0, %cst) {edge_padding_high = dense<[0, 1, 1, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 1, 0]> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x16x16x16xf32>, tensor<f32>) -> tensor<1x18x18x16xf32>
  %1 = "mhlo.transpose"(%0) {permutation = dense<[3, 1, 2, 0]> : tensor<4xi64>} : (tensor<1x18x18x16xf32>) -> tensor<16x18x18x1xf32>
  %2 = "mhlo.transpose"(%arg1) {permutation = dense<[1, 2, 0, 3]> : tensor<4xi64>} : (tensor<1x16x16x16xf32>) -> tensor<16x16x1x16xf32>
  %3 = "mhlo.convolution"(%1, %2) {
    batch_group_count = 16 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    lhs_dilation = dense<1> : tensor<2xi64>,
    precision_config = ["DEFAULT", "DEFAULT"],
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<16x18x18x1xf32>, tensor<16x16x1x16xf32>) -> tensor<1x3x3x16xf32>
  %4 = "mhlo.transpose"(%3) {permutation = dense<[1, 2, 0, 3]> : tensor<4xi64>} : (tensor<1x3x3x16xf32>) -> tensor<3x3x1x16xf32>
  return %4 : tensor<3x3x1x16xf32>
}
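To make the rewritten IR easier to follow, the shapes can be traced with NumPy: `np.pad` with per-dimension (low, high) widths and `np.transpose` with an axes permutation follow the same conventions as `mhlo.pad` (with zero interior padding) and `mhlo.transpose` here. This is only a shape walk-through on zero-filled stand-ins; it doesn't execute the convolution, whose result shape is taken from the IR.

```python
import numpy as np

# Zero-filled stand-ins with the argument shapes of @foo.
arg0 = np.zeros((1, 16, 16, 16), np.float32)
arg1 = np.zeros((1, 16, 16, 16), np.float32)

# %0: mhlo.pad, edge_padding_low = edge_padding_high = [0, 1, 1, 0]
v0 = np.pad(arg0, [(0, 0), (1, 1), (1, 1), (0, 0)])
# %1: mhlo.transpose with permutation [3, 1, 2, 0]
v1 = np.transpose(v0, (3, 1, 2, 0))
# %2: mhlo.transpose with permutation [1, 2, 0, 3]
v2 = np.transpose(arg1, (1, 2, 0, 3))
# %3: result shape of the mhlo.convolution, taken from the IR (not computed)
v3 = np.zeros((1, 3, 3, 16), np.float32)
# %4: mhlo.transpose with permutation [1, 2, 0, 3]
v4 = np.transpose(v3, (1, 2, 0, 3))

for name, v in [("%0", v0), ("%1", v1), ("%2", v2), ("%3", v3), ("%4", v4)]:
    print(name, v.shape)
```

Note that the convolution's lhs batch dimension (dim 0 of %1) is 16 while its result batch dimension is 1; that mismatch is only meaningful because of batch_group_count = 16.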

To repro:

$ gh pr checkout 4676

# Build IREE with ASAN

$ build-asan/iree/tools/iree-run-mlir conv_repro.mlir -iree-hal-target-backends=dylib-llvm-aot -export-all -function-input="1x16x16x16xf32=0.0" -function-input="1x16x16x16xf32=0.0" --iree-llvm-sanitize=address

Reassigning to @asaadaldien, who worked on the depthwise conv lowering.

asaadaldien commented 3 years ago

We don't support lowering mhlo.conv with batch_group_count != 1, and I can see in the IR that it's 16. The conversion should fail instead; this is a bug.

One way to see why this crashes is to look at the conv op's input/output types: the output batch dim is 1 while the input batch dim is 16, which isn't something linalg should accept (another place for improvement: add a verifier for linalg ops).

(tensor<16x18x18x1xf32>, tensor<16x16x1x16xf32>) -> tensor<1x3x3x16xf32>
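The kind of shape check described here could look roughly like the following Python sketch. It is hypothetical (`linalg_conv_shape_problems` is not an IREE function, and a real verifier would live in the MLIR C++ code); it assumes the NHWC input / HWCF filter / NHWC output layout of the preprocessed convolution, stride 1, no dilation and an already-padded input, and it reports both the unsupported batch_group_count and the batch mismatch.

```python
def linalg_conv_shape_problems(input_nhwc, filter_hwcf, output_nhwf,
                               batch_group_count=1):
    """Hypothetical check: list the reasons a conv should not be lowered."""
    n, h, w, c = input_nhwc
    kh, kw, fc, f = filter_hwcf
    on, oh, ow, of = output_nhwf
    problems = []
    if batch_group_count != 1:
        problems.append(f"batch_group_count={batch_group_count} is unsupported; "
                        "the conversion should fail")
    if on != n:
        problems.append(f"output batch {on} != input batch {n}")
    if fc != c:
        problems.append(f"filter input channels {fc} != input channels {c}")
    if of != f:
        problems.append(f"output channels {of} != filter output channels {f}")
    # Stride 1, no dilation, input already padded.
    expected = (h - kh + 1, w - kw + 1)
    if (oh, ow) != expected:
        problems.append(f"output spatial {(oh, ow)} != expected {expected}")
    return problems


for p in linalg_conv_shape_problems((16, 18, 18, 1), (16, 16, 1, 16),
                                    (1, 3, 3, 16), batch_group_count=16):
    print(p)
```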

@phoenix-meadowlark, @stellaraccident is that a training graph?

stellaraccident commented 3 years ago

Yes, I believe it is a training graph.

asaadaldien commented 3 years ago

Does the inference part of the graph work?

This will fail compilation with the upcoming PR (https://github.com/google/iree/pull/4799), which also restricts depthwise convolution to depth_multiplier = 1 only.

To support convs from the backward pass (batch_group_count > 1), we need to extend the named conv op to support the group-batched version.

benvanik commented 3 years ago

It'd be worth checking if this is still having issues through this path now that @asaadaldien has been testing mobilenetv3.

asaadaldien commented 3 years ago

The forward pass of MobileNetV3 is fine; we are hitting this issue on the backward pass because we don't support lowering the gradient variants of the mhlo.conv op to linalg at the moment. cc: @phoenix-meadowlark

hanhanW commented 2 years ago

We're able to run MobileNetV3 today, so closing the issue.

https://github.com/google/iree/blob/2014511033a4a098708490390aeec8e7674a8244/benchmarks/TFLite/CMakeLists.txt#L121-L129