Maimonator opened this issue 3 years ago.
Also, I would love to create a PR if this seems worthwhile :)
Indeed, it's missing from the model dump. I'm not sure how to inject this information into the dump format; feel free to share your opinion.
On the other hand, the information is saved in the JSON model format, which can be obtained with `classifier.save_model("model.json")`.
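As a workaround until the dump carries this information, the class of each tree can be read back from the saved JSON file. The sketch below assumes a `gbtree` booster, whose JSON layout stores the per-tree class indices as an integer list under `learner` → `gradient_booster` → `model` → `tree_info`; other boosters (e.g. `dart`) nest this differently. The helper names are hypothetical, not XGBoost functions.

```python
import json
from collections import defaultdict
from typing import Dict, List


def tree_info_from_model(model: dict) -> List[int]:
    """Return the class (group) index of every tree in a parsed JSON model.

    Assumes the gbtree layout: learner -> gradient_booster -> model -> tree_info.
    """
    return model["learner"]["gradient_booster"]["model"]["tree_info"]


def load_tree_info(path: str) -> List[int]:
    """Read a file written by save_model("model.json") and return its tree_info."""
    with open(path) as f:
        return tree_info_from_model(json.load(f))


def trees_by_class(tree_info: List[int]) -> Dict[int, List[int]]:
    """Map each class index to the indices of the trees that belong to it."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for tree_index, group in enumerate(tree_info):
        groups[group].append(tree_index)
    return dict(groups)


# Usage (requires a fitted xgboost classifier):
# classifier.save_model("model.json")
# groups = trees_by_class(load_tree_info("model.json"))
```

For example, `trees_by_class([0, 1, 2, 0, 1, 2])` returns `{0: [0, 3], 1: [1, 4], 2: [2, 5]}`.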
Two possible solutions I can think of:

1. Add a `'class'` member at the root node. This would make the root node look different from the other nodes, but wouldn't break compatibility for users who have no use for that member.
2. Have `get_dump` return an additional value containing `tree_info`, so usage would look like `trees, tree_info = classifier.get_dump()`. This would also make it easier to support other formats.

I prefer the second option, but I'm not sure what your policy is regarding backward compatibility.
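Option 1 can be emulated on the user side today by post-processing the JSON dump. The sketch below is a hypothetical helper (`add_class_to_root` is not an XGBoost function): it takes one JSON string per tree, as returned by `get_dump(dump_format="json")`, together with a `tree_info` list (for instance one read from the saved JSON model), and adds a `"class"` key to each root node while leaving every other key untouched.

```python
import json
from typing import List


def add_class_to_root(json_dumps: List[str], tree_info: List[int]) -> List[str]:
    """Hypothetical helper: add a "class" key to the root node of each JSON tree dump.

    json_dumps: one JSON string per tree, as produced by a JSON-format dump.
    tree_info:  the class (group) index of each tree, in the same order.
    """
    if len(json_dumps) != len(tree_info):
        raise ValueError("expected exactly one tree_info entry per dumped tree")
    annotated = []
    for dump, group in zip(json_dumps, tree_info):
        root = json.loads(dump)
        root["class"] = group  # only the root node gets the extra key
        annotated.append(json.dumps(root))
    return annotated
```

Consumers that only read the existing node keys keep working, which is the backward-compatibility argument for this option; the cost is that the root node carries a key the other nodes lack.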
We could also have an optional `tree_info` parameter that takes a list and fills it in a C-style fashion:

```python
tree_info = []  # to be filled with tree info
# tree_info= is the proposed parameter; it is not part of the existing get_dump signature
trees = classifier.get_dump(dump_format="json", tree_info=tree_info)
```

This wouldn't break any compatibility, but it isn't as intuitive.
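For comparison, the out-parameter variant can be sketched as a plain function. To keep it self-contained, the dumps and the model's `tree_info` are passed in explicitly instead of being read from a booster; `get_dump_filling` and its parameters are hypothetical names, not part of XGBoost.

```python
from typing import List, Optional


def get_dump_filling(dumps: List[str], model_tree_info: List[int],
                     tree_info: Optional[List[int]] = None) -> List[str]:
    """Hypothetical wrapper illustrating the proposed optional out-parameter.

    If the caller passes a list as tree_info, its contents are replaced in
    place with the group index of each dumped tree. The return value stays
    the plain list of dumps, so existing callers are unaffected.
    """
    if tree_info is not None:
        tree_info[:] = model_tree_info  # mutate the caller's list in place
    return list(dumps)
```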
@trivialfis WDYT?
Sorry for the late reply. @hcho3 Would you like to help take a look? This would help beyond multi-class classification, since I also want to add multi-target regression.
Both options from @Maimonator look fine to me.
OK, I'll probably open a PR in the next few days :) Thanks!
I believe the model dump format is inherited from sklearn; XGBoost didn't invent it. I don't use the model dump very often, so I might not be a good source of advice, but it might help to take a look at sklearn.
I prefer Option 1. There are a couple of packages that rely on the tree dump (dtreeviz, shap), so it's best not to break backward compatibility.
Hey, sorry for the late reply; I've been pretty busy at work.
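@hcho3 What about an optional parameter for `get_dump`?

```python
tree_info = []  # to be filled with tree info
# tree_info= is the proposed parameter; it is not part of the existing get_dump signature
trees = classifier.get_dump(dump_format="json", tree_info=tree_info)
```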
Thanks for following up on the discussion. @Maimonator I think for "pythonic" code, it's best to avoid mutating inputs.
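One hypothetical way to avoid mutating inputs while still not breaking existing callers is an opt-in keyword flag. In this sketch, `get_dump_with_flag` and `with_tree_info` are illustrative names, not XGBoost API, and the dumps and `tree_info` are passed in explicitly to keep it self-contained. The default return value is unchanged, and the flag switches to returning a `(trees, tree_info)` tuple.

```python
from typing import List, Tuple, Union


def get_dump_with_flag(dumps: List[str], model_tree_info: List[int], *,
                       with_tree_info: bool = False
                       ) -> Union[List[str], Tuple[List[str], List[int]]]:
    """Hypothetical opt-in variant: no argument is mutated.

    By default only the dumps are returned, as today; passing
    with_tree_info=True returns a (trees, tree_info) tuple instead.
    """
    trees = list(dumps)
    if with_tree_info:
        return trees, list(model_tree_info)  # fresh copies, caller data untouched
    return trees
```

The trade-off is a return type that depends on an argument, which static type checkers handle less gracefully than a separate method would.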
Hey there! As taken from the comment here:

> Note that each class has its own trees, but there is currently no way to associate a tree with its group. The information has to be available somewhere, otherwise prediction wouldn't work at all, but it's not exported when dumping the model.

I took a look, and it seems that the information is kept in `GBTreeModel`, in the `tree_info` member. When calling `SaveModel` it is saved, but when calling `DumpModel` the only output is the trees, not the group each tree is associated with.

I'd love to hear your input on this. Thanks, and I really appreciate your work!