thesps / conifer

Fast inference of Boosted Decision Trees in FPGAs
Apache License 2.0

Fixed <=, > splitting convention mismatch in conifer #76

Open pviscone opened 3 months ago

pviscone commented 3 months ago

It would be good to have a test case.

Sure, I will do it

Let me know what you think about that.

I'm not sure. Defining the methods on the backend side and then applying them in ModelBase (which lives outside the backends) seems a bit dirty to me. I ended up with a different solution that avoids touching the precision or changing the thresholds, so we don't have to deal with the "where and who should do what" dilemma:

  1. I created a simple dictionary splitting_conventions={library:splitting_convention}, e.g. splitting_conventions={"xgboost":"<"}. (I still have to verify the convention for onnx, tmva and ydf; for now I just put <= as a placeholder.)
  2. All the converters now add the library and splitting_convention entries to the ensembleDict
  3. I added library and splitting_convention to _ensemble_fields (it could be useful for a user to know which library a conifer model was originally trained with)
  4. All the backend objects now have access to splitting_convention and can handle the feature vs. threshold comparison correctly
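
For illustration, the steps above could be sketched roughly like this. It is a minimal sketch, not the actual code in this PR: the helper names goes_left and make_ensemble_dict are hypothetical, the dictionary layout is simplified, and the sklearn entry is included only for contrast (scikit-learn sends a sample to the left child when feature <= threshold, while xgboost uses feature < threshold).

```python
import operator

# Library -> comparison that sends a sample to the LEFT child.
# Only the xgboost entry comes from this thread; sklearn is added for contrast.
splitting_conventions = {
    "xgboost": "<",
    "sklearn": "<=",
}

# Map each convention string to the matching Python comparison.
_COMPARATORS = {"<": operator.lt, "<=": operator.le}


def goes_left(feature_value, threshold, splitting_convention):
    """Return True if the sample takes the left branch under the convention."""
    return _COMPARATORS[splitting_convention](feature_value, threshold)


def make_ensemble_dict(library, trees):
    """Hypothetical converter helper: record the library and its convention."""
    return {
        "library": library,
        "splitting_convention": splitting_conventions[library],
        "trees": trees,
    }


if __name__ == "__main__":
    model = make_ensemble_dict("xgboost", trees=[])
    # The two conventions differ only when the feature equals the threshold.
    print(goes_left(0.5, 0.5, model["splitting_convention"]))
    print(goes_left(0.5, 0.5, "<="))
```

The two conventions disagree only when a feature value lands exactly on a threshold, which is why carrying the convention along with the model avoids having to shift thresholds by one unit of precision.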

I tested the Python, cpp, and Xilinx backends; they build and work fine.

I can't test the fpu and vhdl backends (but the modifications there were minimal, so they should be ok):

  1. fpu: I just can't. When I import conifer, it warns me that it is not loading the modules needed to run the fpu backend
  2. vhdl: It hangs forever on tree.padTree(self.max_depth) in the constructor (this is not related to my modifications). The call is also unsafe, at least with xgboost models, since it relies on max_depth, which is a dummy value. I think it should be fixed

    https://github.com/thesps/conifer/blob/d90a08c8f9ec80fdc6d38f123cf1836372afd010/conifer/backends/vhdl/writer.py#L48

Let me know if this solution is ok for you.

pviscone commented 3 months ago

P.S. tree.padTree(self.max_depth) is also present in the rolled version of the xilinx backend (which is not used by default, so the default unrolled version works fine)

https://github.com/thesps/conifer/blob/d90a08c8f9ec80fdc6d38f123cf1836372afd010/conifer/backends/xilinxhls/writer.py#L115-L118