microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/

Save User Function. #3082

Open doggy1232 opened 6 years ago

doggy1232 commented 6 years ago

I have written a user function in C++ that runs on the GPU. It is built as a DLL and used from Python. I registered it using register_native_user_function, and I am using it successfully for forward and backward propagation while training a neural network.

Now I want to create checkpoints while training and save a final model. My custom functions don't have any data that they need to save, so I thought this would work fine, but when I run Trainer.save_checkpoint(filename) I get this error:

ValueError: UserFunction with op name 'CustomOp' has not been registered.

[CALL STACK]

  • CNTK::Internal:: UseSparseGradientAggregationInDataParallelSGD
  • CNTK::Function:: ModuleName
  • CNTK::Internal:: RegisterUDFDeserializeCallbackWrapper (x2)
  • CNTK::Internal:: UseSparseGradientAggregationInDataParallelSGD
  • CNTK::Function:: Save
  • CNTK::Trainer:: RestoreFromCheckpoint
  • CNTK::Trainer:: SaveCheckpoint
  • PyInit__cntk_py (x2)
  • PyCFunction_Call
  • PyFunction_FastCallDict
  • PyEval_EvalFrameDefault (x2)
  • PyEval_GetFuncDesc (x2)

The first thing I do in the code is register the function:

OP_id = "CustomOp"
C.register_native_user_function(OP_id, "OP_DLL", "CreateCustomOp")

Then, later in the model definition, I create the layer like this:

layer3 = C.native_user_function(OP_id, [layer1, layer2], None, "CustomOpLayer")

I don't understand why it's saying the op is not registered. I even tried to register it directly before I called save_checkpoint and it generated an error saying the op was already registered.

ke1337 commented 6 years ago

Did you implement serialize/deserialize? Please take a look at the user function test as an example.
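The general shape of a stateless serialize/deserialize pair can be sketched in plain Python. This is only an illustration of the pattern, not CNTK code: StatelessOp, save, and load are hypothetical names made up for this sketch, and the thread does not confirm whether an empty pair is enough for a native C++ op.

```python
import json


class StatelessOp:
    """Hypothetical op with no parameters to persist (not a CNTK class)."""

    op_name = "CustomOp"

    def __init__(self, inputs, name=""):
        self.inputs = list(inputs)
        self.name = name

    def serialize(self):
        # Nothing to save, so the state is an empty dictionary.
        return {}

    @staticmethod
    def deserialize(inputs, name, state):
        # The state is ignored because the op is stateless.
        return StatelessOp(inputs, name)


def save(op):
    """Hypothetical saver: store the op name, instance name, and state as JSON."""
    return json.dumps({"op": op.op_name, "name": op.name, "state": op.serialize()})


def load(blob, inputs, deserializers):
    """Hypothetical loader: rebuild an op by looking up its deserializer by op name."""
    record = json.loads(blob)
    factory = deserializers.get(record["op"])
    if factory is None:
        raise ValueError("no deserializer registered for " + record["op"])
    return factory(inputs, record["name"], record["state"])


blob = save(StatelessOp(["layer1", "layer2"], "CustomOpLayer"))
restored = load(blob, ["layer1", "layer2"], {"CustomOp": StatelessOp.deserialize})
print(restored.name, restored.serialize())
```

In this sketch the saved record still needs the op name and a matching deserializer even though the state itself is empty, so the pair has to exist even when there is nothing to store.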

doggy1232 commented 6 years ago

Do I need it? As I said, the function doesn't have any parameters to save. Should I just implement empty serialize and deserialize methods?