chainer / onnx-chainer

Add-on package for ONNX format support in Chainer
MIT License

Retain hook does not support ndarray #216

Closed: disktnk closed this issue 5 years ago

disktnk commented 5 years ago

During export, onnx-chainer uses RetainInputHook to retain references from input values to temporary output values. This hook only supports chainer.Variable, so if the target model contains a chainer.Chain that receives a numpy.ndarray as an argument, export raises a ValueError. Minimal reproduction:

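import chainer
import numpy as np
from onnx_chainer import export

class Model(chainer.Chain):

    def forward(self, x, h):
        # h arrives here as a raw numpy.ndarray, not a chainer.Variable
        return x + h

class ParentModel(chainer.Chain):

    def __init__(self):
        super().__init__()
        # Plain ndarray attribute: neither a registered parameter nor a Variable
        self.h = np.array(5, dtype=np.float32)
        with self.init_scope():
            self.m = Model()

    def forward(self, x):
        return self.m(x, self.h)

x = chainer.Variable(np.array(3, dtype=np.float32))
export(ParentModel(), (x,))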
E           ValueError: type <class 'numpy.ndarray'> is not supported to retain input
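The failing type check, and one possible way to relax it, can be sketched as below. This is a minimal, self-contained illustration, not onnx-chainer's actual RetainInputHook code: `Variable` is a hypothetical stand-in for chainer.Variable (so the sketch runs with NumPy alone), and `retain_input_strict` / `retain_input_lenient` are hypothetical helper names. The strict version mirrors the behavior reported above (only Variables can be retained); the lenient version assumes a fix that wraps a raw ndarray into a Variable first, which in real Chainer code could be done with `chainer.as_variable`.

```python
import numpy as np


class Variable:
    """Hypothetical stand-in for chainer.Variable: just holds an array."""

    def __init__(self, array):
        self.array = array


def retain_input_strict(value, retained):
    """Mimics the reported behavior: only Variable inputs can be retained."""
    if isinstance(value, Variable):
        retained.append(value)
        return value
    raise ValueError(
        'type {} is not supported to retain input'.format(type(value)))


def retain_input_lenient(value, retained):
    """Hypothetical fix: wrap a raw ndarray in a Variable before retaining."""
    if isinstance(value, np.ndarray):
        value = Variable(value)
    return retain_input_strict(value, retained)


if __name__ == '__main__':
    h = np.array(5, dtype=np.float32)
    retained = []
    try:
        retain_input_strict(h, retained)
    except ValueError as e:
        print(e)  # same message as the traceback in this issue
    v = retain_input_lenient(h, retained)
    print(type(v).__name__, len(retained))
```

Wrapping at the hook boundary keeps user models unchanged, at the cost of the hook creating Variables the model itself never saw.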

ref #201