microsoft / ELL

Embedded Learning Library
https://microsoft.github.io/ELL

python protonn help #220

Closed. FoxMarts closed this issue 5 years ago.

FoxMarts commented 5 years ago

Hello, I'm creating a ProtoNN model to run on an Arm Cortex-M4 in C++. I'm using Python to train it. Currently this produces a model that takes and returns vectors of doubles. Is it possible to change this so that the model works only with floats?

lovettchris commented 5 years ago

Ah, good question. We should make that a switch on the protonn trainer command line. The problem is that the class ProtoNNPredictorNode is not templatized on the element type; it is hard-coded to double. We'll need to make it a template. Then we need to fix Map ProtoNNPredictor::GetMap() const so that it can build either a float-based or a double-based Map; perhaps it should take a dtype argument. The question is whether you also want to "train" with floats. If so, the change goes deeper and the protonntrainer will need to change. Presumably, though, you can train with doubles and then just cast the resulting weights to floats.
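Below is a minimal sketch of the "train with doubles, then cast the weights to floats" idea. It assumes a simplified ProtoNN predictor that stores a projection matrix W, prototypes B and label vectors Z. It scores each label as the sum over prototypes j of Z_j * exp(-gamma^2 * ||B_j - W x||^2). The names ProtoNNSketch and CastModel are hypothetical and are not ELL's ProtoNNPredictorNode or any ELL API. The sketch only shows that once the element type is a template parameter, a model trained in double can be converted to float afterwards.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical, simplified ProtoNN predictor templatized on the element type.
// This is NOT ELL's ProtoNNPredictorNode; it only illustrates the idea.
template <typename ElementType>
struct ProtoNNSketch
{
    std::size_t inputDim;       // d: length of the input vector x
    std::size_t projectedDim;   // d_hat: length of the projection W x
    std::size_t numPrototypes;  // m: number of prototypes
    std::size_t numLabels;      // L: number of output scores
    ElementType gamma;          // RBF kernel width
    std::vector<ElementType> W; // projectedDim x inputDim, row-major
    std::vector<ElementType> B; // numPrototypes x projectedDim, row-major
    std::vector<ElementType> Z; // numLabels x numPrototypes, row-major

    // Returns one score per label: sum_j Z[:, j] * exp(-gamma^2 * ||B_j - W x||^2)
    std::vector<ElementType> Predict(const std::vector<ElementType>& x) const
    {
        assert(x.size() == inputDim);

        // Project the input into the low-dimensional space: projected = W x
        std::vector<ElementType> projected(projectedDim, ElementType(0));
        for (std::size_t r = 0; r < projectedDim; ++r)
            for (std::size_t c = 0; c < inputDim; ++c)
                projected[r] += W[r * inputDim + c] * x[c];

        std::vector<ElementType> scores(numLabels, ElementType(0));
        for (std::size_t j = 0; j < numPrototypes; ++j)
        {
            // Squared distance between prototype j and the projected input
            ElementType dist2 = 0;
            for (std::size_t k = 0; k < projectedDim; ++k)
            {
                ElementType diff = B[j * projectedDim + k] - projected[k];
                dist2 += diff * diff;
            }
            // RBF similarity, then accumulate prototype j's label vector
            ElementType similarity = std::exp(-gamma * gamma * dist2);
            for (std::size_t l = 0; l < numLabels; ++l)
                scores[l] += Z[l * numPrototypes + j] * similarity;
        }
        return scores;
    }
};

// Hypothetical helper: casts every parameter of a model trained with one
// element type (e.g. double) to another (e.g. float for a Cortex-M4 target).
template <typename TargetType, typename SourceType>
ProtoNNSketch<TargetType> CastModel(const ProtoNNSketch<SourceType>& src)
{
    auto castVector = [](const std::vector<SourceType>& v) {
        return std::vector<TargetType>(v.begin(), v.end());
    };
    return ProtoNNSketch<TargetType>{src.inputDim, src.projectedDim, src.numPrototypes, src.numLabels,
                                     static_cast<TargetType>(src.gamma),
                                     castVector(src.W), castVector(src.B), castVector(src.Z)};
}
```

For example, ProtoNNSketch<float> f = CastModel<float>(trainedDoubleModel); yields a float model whose scores should match the double model up to float rounding. Training itself stays in double, which is the cheaper of the two changes described above.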

FoxMarts commented 5 years ago

Thanks for the quick response. I edited the model.ell file directly. I changed all the values from double to float and the type of every input, output and node to float, i.e. replacing InputNode<double> with InputNode<float>. But now when I run the wrap tool I get "exception: mismatched types for function."

lovettchris commented 5 years ago

Right, unfortunately it is not that simple, because of how ProtoNNPredictor::GetMap() is implemented, as mentioned above.

FoxMarts commented 5 years ago

Thanks for the answer; I've already changed my code to use double. Now I have just one question about the generated header. I realized I can compile it for C, but I don't know what to pass as the first argument of void model_Predict(void* context, double* input, double* output);

lovettchris commented 5 years ago

Just pass nullptr for the first parameter. It is a context object that is only meaningful if your model has "callbacks" during predict. Callbacks are created by SourceNodes and SinkNodes.
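Below is a minimal sketch of calling the generated predict function with a null context. It assumes a model with no SourceNode/SinkNode callbacks. The body of model_Predict below is a hypothetical stub so that the snippet compiles on its own. The sizes kInputSize and kOutputSize are also made up. In a real build you would include the header produced by the wrap step and link the compiled model instead of defining the stub.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sizes; the real ones come from your model's input/output shapes.
constexpr std::size_t kInputSize = 3;
constexpr std::size_t kOutputSize = 2;

// Stand-in for the function declared in the generated header. Replace this
// stub with the header and the compiled model in a real project.
extern "C" void model_Predict(void* context, double* input, double* output)
{
    (void)context; // unused, as in a model without callbacks
    // Dummy computation so the example produces checkable values.
    output[0] = input[0] + input[1] + input[2];
    output[1] = input[0] * input[1] * input[2];
}

// Runs one prediction, passing nullptr as the context argument.
std::vector<double> RunPrediction(std::vector<double> input)
{
    assert(input.size() == kInputSize);
    std::vector<double> output(kOutputSize, 0.0);
    model_Predict(nullptr, input.data(), output.data());
    return output;
}
```

The caller owns both buffers and sizes them to match the model. The only thing the context argument changes is whether callbacks have somewhere to go, so nullptr is enough here. From plain C, pass NULL instead of nullptr.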