dmlc / MXNet.cpp

C++ interface for mxnet

inline function “_MinusScalar” fails to work. #39

Closed LynetteXing1991 closed 7 years ago

LynetteXing1991 commented 8 years ago

When writing the GRU unit, I need to use a "1-update" matrix. Here is my source code:

Symbol gruLayer::step(Symbol &input) {
  Symbol reset_i = FullyConnected("reset_i", input, inputWeights[0], biases[0], layerSize);
  Symbol reset = Activation("reset_act",
      reset_i + FullyConnected("reset_h", previousHidden, hiddenWeights[0], biases[0], layerSize, true),
      ActivationActType::sigmoid);
  Symbol update_i = FullyConnected("update_i", input, inputWeights[1], biases[1], layerSize);
  Symbol update = Activation("update_act",
      update_i + FullyConnected("update_h", previousHidden, hiddenWeights[1], biases[1], layerSize, true),
      ActivationActType::sigmoid);
  Symbol fake_hidden_i = FullyConnected("cell_i", input, inputWeights[2], biases[2], layerSize);
  Symbol fake_hidden_h = FullyConnected("cell_h", previousHidden_reset, hiddenWeights[2], biases[2], layerSize, true);
  Symbol fake_hidden = Activation("cell_act", fake_hidden_h + fake_hidden_i, ActivationActType::tanh);
  //Symbol tmp = update - (mx_float)1.9f;
  hidden = update_previousHidden - (update - (mx_float)1.0f)* fake_hidden;
  hidden = update*previousHidden;
  previousHidden = hidden;
  return hidden;
}

But the expression "update - (mx_float)1.0f" fails inside:

Symbol Operator::CreateSymbol(const std::string &name) {
  const char *pname = name == "" ? nullptr : name.c_str();

  SymbolHandle symbol_handle;
  std::vector<const char *> input_keys;
  std::vector<const char *> param_keys;
  std::vector<const char *> param_values;

  for (auto &data : params_) {
    param_keys.push_back(data.first.c_str());
    param_values.push_back(data.second.c_str());
  }
  for (auto &data : this->input_keys) {
    input_keys.push_back(data.c_str());
  }
  const char **input_keys_p = (input_keys.size() > 0) ? input_keys.data() : nullptr;

  MXSymbolCreateAtomicSymbol(handle_, param_keys.size(), param_keys.data(), param_values.data(), &symbol_handle);
  MXSymbolCompose(symbol_handle, pname, input_values.size(), input_keys_p, input_values.data());
  return Symbol(symbol_handle);
}

The failure is at the call MXSymbolCreateAtomicSymbol(handle_, param_keys.size(), param_keys.data(), param_values.data(), &symbol_handle); which does not give back a symbol_handle. Has anyone run into the same problem and found a solution?
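
For readers following along, the "1-update" term presumably comes from the usual GRU state blend, h = u * h_prev + (1 - u) * h_cand, which can be rewritten as h = u * h_prev - (u - 1) * h_cand. The second form is the one that needs a "symbol minus scalar" operation. The sketch below checks that the two forms agree on plain floats. It does not use MXNet.cpp at all, the helper names (gru_blend, gru_blend_minus_scalar, nearly_equal) are hypothetical, and the blend formula is the common GRU convention, assumed here rather than taken from the thread.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: element-wise h = u * h_prev + (1 - u) * h_cand.
std::vector<float> gru_blend(const std::vector<float>& u,
                             const std::vector<float>& h_prev,
                             const std::vector<float>& h_cand) {
  assert(u.size() == h_prev.size() && u.size() == h_cand.size());
  std::vector<float> h(u.size());
  for (std::size_t i = 0; i < u.size(); ++i) {
    h[i] = u[i] * h_prev[i] + (1.0f - u[i]) * h_cand[i];
  }
  return h;
}

// Same blend written with "u - 1": h = u * h_prev - (u - 1) * h_cand.
std::vector<float> gru_blend_minus_scalar(const std::vector<float>& u,
                                          const std::vector<float>& h_prev,
                                          const std::vector<float>& h_cand) {
  assert(u.size() == h_prev.size() && u.size() == h_cand.size());
  std::vector<float> h(u.size());
  for (std::size_t i = 0; i < u.size(); ++i) {
    h[i] = u[i] * h_prev[i] - (u[i] - 1.0f) * h_cand[i];
  }
  return h;
}

// True when both vectors have the same length and agree within tol.
bool nearly_equal(const std::vector<float>& a, const std::vector<float>& b,
                  float tol = 1e-6f) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > tol) return false;
  }
  return true;
}
```

Calling both helpers with the same update gate, previous state, and candidate state should give matching results, so the choice between "1 - update" and "update - 1" is only a matter of sign handling in the surrounding expression.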

zhangchen-qinyinghua commented 8 years ago

I didn't find any bugs around _MinusScalar. However, I think you shouldn't set no_bias to true in the FullyConnected operator when you actually need a bias. If you really don't want a bias, try:

auto reset_h = Operator("FullyConnected")
         .SetParam("num_hidden", layerSize)
         .SetParam("no_bias", true)
         .SetInput("data", previousHidden)
         .SetInput("weight", hiddenWeights[0])
         .CreateSymbol("reset_h");

No matter what, we really want an RNN example.

LynetteXing1991 commented 8 years ago

Thank you very much for showing how to write the no_bias version.

But are you sure there is no bug in _MinusScalar? When I run only the _MinusScalar function, as follows,

[image]

it raises an exception:

[image]

MXSymbolCreateAtomicSymbol can't return a proper symbol handle.

LynetteXing1991 commented 8 years ago

Directly using the _MinusScalar operator like:

auto fake_hidden_h = Operator("_MinusScalar")
.SetParam("scalar_on_left", false)
.SetParam("scalar", 1.0)
.SetInput("lhs", sym_input)
.CreateSymbol("fake_hidden_h");

also fails.....

zhangchen-qinyinghua commented 8 years ago

@lx75249 @hjk41 Could you please help fix this issue?

hjk41 commented 8 years ago

This works fine for me: update - (mx_float)1.0f

This fails:

auto fake_hidden_h = Operator("_MinusScalar")
.SetParam("scalar_on_left", false)
.SetParam("scalar", 1.0)
.SetInput("lhs", sym_input)
.CreateSymbol("fake_hidden_h");

It should be written like this:

auto fake_hidden_h = Operator("_MinusScalar")
      .SetParam("scalar", 1.0)
      .SetInput("data", outputs[nLayers - 1])
      .CreateSymbol("fake_hidden_h");

hjk41 commented 8 years ago

@LynetteXing1991 Since I cannot reproduce your problem, could you recompile libmxnet.dll with the Debug flag and then add a breakpoint at c_api_error.h:36? You will then be able to see the error message, which should tell us more.
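
The error text can often be surfaced from the calling side as well, without a debugger. As far as I know, MXNet's C API functions return 0 on success and a nonzero code on failure, and MXGetLastError() returns the message for the most recent failure; please treat that as an assumption to check against your copy of c_api.h. The sketch below shows the check-and-report pattern using stand-in functions so that it compiles on its own. fake_create_atomic_symbol, fake_get_last_error, check_call, and try_create are hypothetical names, not part of MXNet, and the failure condition and message text are made up for the demonstration.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Stand-in for the library's "last error" storage (hypothetical, not MXNet).
static std::string g_last_error;

// Stand-in for a last-error getter such as the C API is assumed to provide.
const char* fake_get_last_error() { return g_last_error.c_str(); }

// Stand-in for a C-style create call: returns 0 on success, -1 on failure.
// The rule "only the key 'scalar' is accepted" is invented for this demo.
int fake_create_atomic_symbol(const char* param_key, int* out_handle) {
  assert(out_handle != nullptr);
  if (std::string(param_key) != "scalar") {
    g_last_error = std::string("unknown parameter key: ") + param_key;
    return -1;
  }
  *out_handle = 42;  // arbitrary dummy handle value
  return 0;
}

// Turn a nonzero return code into an exception carrying the error message.
inline void check_call(int ret) {
  if (ret != 0) throw std::runtime_error(fake_get_last_error());
}

// Returns an empty string on success, otherwise the reported error message.
std::string try_create(const char* param_key) {
  int handle = 0;
  try {
    check_call(fake_create_atomic_symbol(param_key, &handle));
  } catch (const std::runtime_error& e) {
    return e.what();
  }
  return "";
}
```

With the real library, the equivalent step would be to look at the int returned by MXSymbolCreateAtomicSymbol in Operator::CreateSymbol and print the last-error string when it is nonzero, instead of discarding the return value.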