ankane / onnxruntime-ruby

Run ONNX models in Ruby
MIT License

When input is bool tensor #6

Closed mib32 closed 3 years ago

mib32 commented 3 years ago

I'm trying to use it with a network whose input is a boolean tensor, but inference fails.

In inference_session.rb:226 there is this code:

if tensor_type == :bool
  tensor_type = :uchar
  flat_input = flat_input.map { |v| v ? 1 : 0 }
end

So it detects the 'bool' type from the ONNX model, meaning the model is designed to accept bool, and then switches the type to uchar. For me, inference then fails with the error OnnxRuntime::Error: type 17 is not supported in this function, which, as I understand it, kind of makes sense.
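A minimal pure-Ruby sketch of what that conversion amounts to, assuming ONNX Runtime's usual layout of one byte per bool element: the 0/1 bytes are the same whether the buffer is labelled uchar or bool, so what must stay correct is the element type code passed to ONNX Runtime (UINT8 is 2 and BOOL is 9 in the ONNX type enum). The helper `encode_bool_tensor` and the `ONNX_ELEMENT_TYPES` table are illustrative names, not part of onnxruntime-ruby.

```ruby
# Element type codes from the ONNX TensorProto.DataType enum;
# ONNX Runtime's ONNXTensorElementDataType uses the same numbers.
ONNX_ELEMENT_TYPES = { uint8: 2, bool: 9 }.freeze

# Hypothetical helper: returns the raw bytes for a bool tensor plus its type code.
def encode_bool_tensor(flat_input)
  # true -> 1, false/nil -> 0, one unsigned byte ("C") per element
  bytes = flat_input.map { |v| v ? 1 : 0 }.pack("C*")
  # The bytes are identical to a uint8 buffer; only the type code differs,
  # and an input declared as tensor(bool) expects BOOL (9), not UINT8 (2).
  [bytes, ONNX_ELEMENT_TYPES.fetch(:bool)]
end

bytes, type_code = encode_bool_tensor([true, false, true])
p bytes.bytes # => [1, 0, 1]
p type_code   # => 9
```

In other words, mapping true/false to 0/1 bytes is harmless on its own; relabelling the tensor as uint8 is the part that can clash with a model whose input is declared as bool.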

One workaround, I guess, would be to make the ONNX model accept the tensor as uchar and convert it back to bool inside its forward function. But for some reason that gives me weird, inconsistent Gather errors. Moreover, it requires changing the model's architecture, so I couldn't use it with models that were trained before the change.

Another thing I tried was replacing that block with this (adapted from FFI::Pointer#write_array_of_type):

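if tensor_type == :bool
  elem_size = ::FFI.type_size(::FFI::TYPE_BOOL)
  flat_input.each_with_index do |val, i|
    # stop before writing past the end of the buffer
    break unless i * elem_size < input_tensor_values.size
    # write each element at its own offset (Pointer#write always uses offset 0)
    input_tensor_values.put(::FFI::TYPE_BOOL, i * elem_size, val)
  end
end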

But that totally doesn't work.

I'd love to hear if you have any experience with this. What I don't understand is why it isn't possible to just pass the array of bools natively. Why did you need the special case for if tensor_type == :bool that converts them to bytes?

mib32 commented 3 years ago

Also asked about this in the FFI repository: https://github.com/ffi/ffi/issues/867

mib32 commented 3 years ago

Sorry, my fault. It actually works fine with:

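if tensor_type == :bool
  elem_size = ::FFI.type_size(::FFI::TYPE_BOOL)
  flat_input.each_with_index do |val, i|
    # stop before writing past the end of the buffer
    break unless i * elem_size < input_tensor_values.size
    # write each element at its own offset (Pointer#write always uses offset 0)
    input_tensor_values.put(::FFI::TYPE_BOOL, i * elem_size, val)
  end
end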
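A note on per-element writes like this: FFI's Pointer#write(type, value) is shorthand for put(type, 0, value), so it always writes at offset 0, while write_array_of_type advances the offset by i * size for each element. The sketch below shows the same idea in plain Ruby, using a binary String as a stand-in for the zero-filled native buffer; `write_bools` and `BOOL_SIZE` are hypothetical names, and the one-byte bool size is an assumption that matches common platforms.

```ruby
BOOL_SIZE = 1 # assumed: one byte per bool element

# Hypothetical helper mirroring write_array_of_type: each element goes to
# its own offset (i * BOOL_SIZE) instead of always to offset 0.
def write_bools(buffer, values)
  values.each_with_index do |val, i|
    offset = i * BOOL_SIZE
    # Stop instead of writing past the end of the buffer
    break unless offset < buffer.bytesize
    buffer.setbyte(offset, val ? 1 : 0)
  end
  buffer
end

buf = "\0".b * 4 # zero-filled, like a freshly allocated FFI::MemoryPointer
write_bools(buf, [true, false, true, true])
p buf.bytes # => [1, 0, 1, 1]

# Writing every value at offset 0 keeps only the last value, in the first byte:
wrong = "\0".b * 4
[false, true, false, true].each { |v| wrong.setbyte(0, v ? 1 : 0) }
p wrong.bytes # => [1, 0, 0, 0]
```

Because the buffer starts zero-filled, the offset-0 version can look correct when most inputs are false, which makes this kind of bug easy to miss.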

I recommend making this the default for the bool type instead of the current conversion. I can send a PR.

ankane commented 3 years ago

Hey @mib32, I fixed an error and added tests for the bool type, but feel free to send a PR if there's still an issue or a better approach.
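Bool support also has to work in the other direction, since bool output tensors come back as one byte per element that has to become true/false in Ruby. A hypothetical round-trip check of that byte convention is sketched below; `encode_bools` and `decode_bool_tensor` are illustrative names and this is not the gem's actual test suite.

```ruby
# Hypothetical encoder, same convention as the uchar conversion: true -> 1, false -> 0
def encode_bools(values)
  values.map { |v| v ? 1 : 0 }.pack("C*")
end

# Hypothetical decoder: bytes of a bool tensor -> Ruby booleans (nonzero = true)
def decode_bool_tensor(bytes)
  bytes.unpack("C*").map { |b| b != 0 }
end

input = [true, false, false, true]
roundtrip = decode_bool_tensor(encode_bools(input))
p roundtrip == input # => true
```

Treating any nonzero byte as true, rather than only 1, keeps the decoder tolerant of buffers that were not produced by this encoder.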

mib32 commented 3 years ago

@ankane Thanks, it's perfect now 👏

ankane commented 3 years ago

Great, just pushed out a new release.