dmlc / minpy

NumPy interface with mixed backend execution
https://minpy.readthedocs.io/en/latest/

Fix is_train and multiple outputs for mxnet symbol #168

Closed: Taco-W closed this 7 years ago

Taco-W commented 7 years ago

@GaiYu0 @sneakerkg @jermainewang Guys, please check.

Example: disabling is_train (it defaults to True):

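  import mxnet as mx
  from minpy import core

  # hidden_size, num_classes, batch_size and input_size (a tuple) are defined elsewhere.
  net = mx.sym.Variable('X')  # data input; its name must match the 'X' key in input_shapes
  net = mx.sym.Flatten(data=net)
  net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=hidden_size)
  net = mx.sym.Activation(data=net, act_type='relu')
  net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
  net = mx.sym.SoftmaxOutput(data=net, name='softmax', normalization='batch')
  input_shapes = {'X': (batch_size,) + input_size, 'softmax_label': (batch_size,)}
  self.fwd_fn = core.Function(net, input_shapes=input_shapes)
  # Switch the wrapped symbol out of training mode (is_train defaults to True).
  self.fwd_fn.is_train(False)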
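A minimal pure-Python sketch of the pattern this example relies on, assuming the wrapper simply stores a mode flag (defaulting to True) and forwards it to the backend on every call. ToyFunction, scale_backend and the backend signature are hypothetical stand-ins for illustration only, not minpy's core.Function or MXNet's executor.

```python
from typing import Callable, List, Sequence

# Hypothetical backend signature: takes the inputs and an is_train flag,
# returns a list of outputs (one entry per network output).
Backend = Callable[[Sequence[float], bool], List[float]]


class ToyFunction:
    """Illustrative wrapper with a stateful is_train toggle (not minpy's API)."""

    def __init__(self, backend: Backend) -> None:
        self._backend = backend
        self._is_train = True  # training mode is the default, as in the example above

    def is_train(self, flag: bool) -> None:
        # All later calls run in the chosen mode until it is toggled again.
        self._is_train = flag

    def __call__(self, inputs: Sequence[float]) -> List[float]:
        # Forward the current mode to the backend on every call.
        return self._backend(inputs, self._is_train)


def scale_backend(inputs: Sequence[float], is_train: bool) -> List[float]:
    # Toy behaviour that differs by mode: scale by 0.5 only while training.
    factor = 0.5 if is_train else 1.0
    return [x * factor for x in inputs]


fwd_fn = ToyFunction(scale_backend)
print(fwd_fn([2.0, 4.0]))  # [1.0, 2.0] (training mode)
fwd_fn.is_train(False)
print(fwd_fn([2.0, 4.0]))  # [2.0, 4.0] (inference mode)
```

Keeping the flag as wrapper state means the call site stays unchanged; the trade-off is that the mode is implicit, so code that shares one wrapper between training and evaluation has to remember to toggle it back.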

Example: using multiple outputs. If the network above had multiple outputs, you could access each one by index:

  self.fwd_fn(inputs)[i]

When the network has a single output, no indexing is needed, so existing code keeps working (backward compatible). You can reference the value directly:

  self.fwd_fn(inputs)
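The backward-compatible convention described above can be sketched as a small helper that unwraps a lone output and otherwise returns an indexable list. unwrap_outputs is a hypothetical name for illustration; it is an assumption about the behaviour, not code taken from minpy.

```python
from typing import Any, List, Sequence, Union


def unwrap_outputs(outputs: Sequence[Any]) -> Union[Any, List[Any]]:
    """Return the sole output unwrapped, or all outputs as an indexable list.

    Hypothetical helper illustrating the back-compatible convention;
    not minpy's implementation.
    """
    if len(outputs) == 1:
        return outputs[0]  # single-output network: no indexing needed
    return list(outputs)  # multi-output network: caller indexes with [i]


# Single output: callers written before multi-output support keep working.
print(unwrap_outputs(["probs"]))  # probs
# Multiple outputs: access each one by index.
print(unwrap_outputs(["probs", "loss"])[1])  # loss
```

The cost of this convention is that the return type depends on how many outputs the network has, so generic code that must handle arbitrary networks needs to check which case it received.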

GaiYu0 commented 7 years ago

Yes, it works. Thank you very much!