mil-tokyo / webdnn

The Fastest DNN Running Framework on Web Browser
https://mil-tokyo.github.io/webdnn

"Gather" operator not implemented yet error #926

Open Wardo82 opened 4 years ago

Wardo82 commented 4 years ago

I want to export a model from PyTorch, but this error from the converter source keeps showing up:

raise NotImplementedError("[ONNXConverter] Operator \"Gather\" is not supported yet.")

milhidaka commented 4 years ago

As the error message says, the required operator is not implemented yet. You can implement it yourself, using the implementations of other operators as a reference.

Wardo82 commented 4 years ago

Hi, thanks for the quick reply! Looking at the docs I found

```python
node = onnx.helper.make_node(
    'Gather',
    inputs=['data', 'indices'],
    outputs=['y'],
    axis=1,
)
data = np.random.randn(5, 4, 3, 2).astype(np.float32)
indices = np.array([0, 1, 3])
y = np.take(data, indices, axis=1)

expect(node, inputs=[data, indices.astype(np.int64)], outputs=[y],
       name='test_gather_1')
```

and came up with

```python
@ONNXConverter.register_handler("Gather")
def _convert_gather(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])
    indices = converter.get_variable(onnx_op.input[1])
    axis = onnx_op.attribute[0].i

    new_data = x.data.copy().astype(np.float32)
    new_indices = indices.data.copy().astype(np.int64)
    y = np.take(new_data, new_indices, axis=axis)

    y = ConstantVariable(y, Order(y.shape))
    converter.set_variable(onnx_op.output[0], y)
```

What do you think? I now get a dimensionality error: `ValueError: Unification failed: Number of dimension mismatch (self.ndim) = 2 (other.ndim) = 1`
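
For reference when debugging rank mismatches like this one, here is a minimal NumPy-only sketch of how `np.take` (which the ONNX docs example uses as the reference for `Gather`) determines the output shape: the gathered axis is replaced by the shape of `indices`, so the output rank is `data.ndim + indices.ndim - 1`. The helper name `gather_output_shape` is hypothetical and is not part of WebDNN or ONNX, and the sketch does not diagnose the WebDNN error itself; it only shows what rank to expect from the gathered array.

```python
import numpy as np


def gather_output_shape(data_shape, indices_shape, axis):
    """Expected shape of np.take(data, indices, axis=axis).

    Hypothetical helper for illustration only: the gathered axis is
    replaced by the full shape of the indices array.
    """
    axis = axis % len(data_shape)  # normalise a negative axis
    return (tuple(data_shape[:axis])
            + tuple(indices_shape)
            + tuple(data_shape[axis + 1:]))


data = np.random.randn(5, 4, 3, 2).astype(np.float32)

# 1-D indices: rank is unchanged, axis 1 shrinks from 4 to 3.
indices = np.array([0, 1, 3], dtype=np.int64)
y = np.take(data, indices, axis=1)
assert y.shape == gather_output_shape(data.shape, indices.shape, axis=1)

# 2-D indices: rank grows by one.
indices_2d = np.array([[0, 1], [2, 3]], dtype=np.int64)
y2 = np.take(data, indices_2d, axis=1)
assert y2.shape == gather_output_shape(data.shape, indices_2d.shape, axis=1)

# 0-D (scalar) indices: rank drops by one.
scalar_index = np.array(2, dtype=np.int64)
y0 = np.take(data, scalar_index, axis=1)
assert y0.shape == gather_output_shape(data.shape, scalar_index.shape, axis=1)
```

Printing `x.data.shape`, `indices.data.shape`, and `y.shape` inside the handler and comparing them against this rule may help narrow down which variable ends up with the unexpected number of dimensions.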