mwydmuch / napkinXC

Extremely simple and fast extreme multi-class and multi-label classifiers.
MIT License

Support for custom tree (tree_structure in python interface) #16

Closed: xiaohan2012 closed this issue 3 years ago

xiaohan2012 commented 3 years ago

Hi,

Thanks for writing this software, which is very useful!

I'm currently experimenting with the effect of label trees and wish to load trees from file.

Is it possible to pass a string to the tree_structure parameter of the models.PLT class so that a custom tree can be loaded? It seems that the current Python interface does not support it.

If it helps, I can make a pull request; it would be nice to get some pointers, e.g., where and what to modify.

Cheers, Han

mwydmuch commented 3 years ago

Hi @xiaohan2012, thank you for your kind words.

It's true that the Python interface doesn't have the tree_structure parameter right now (I would like to improve this functionality before exposing it), but the constructors accept **kwargs that are passed to the underlying C++ module. This makes it possible to use, from Python as well, all the undocumented parameters implemented in the src/args.cpp file. So, for experimental purposes, new options can be implemented in C++ alone, without updating the Python module.
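To illustrate the forwarding mechanism in general terms, here is a minimal sketch of a wrapper whose constructor names a few documented parameters explicitly and forwards any other keyword argument untouched to a backend. Everything here is hypothetical: `HypotheticalBackend`, `Model`, and the `--key value` argument format are stand-ins and not napkinXC's actual classes or option syntax.

```python
from typing import Any, Dict, List


class HypotheticalBackend:
    """Stand-in for a compiled (e.g. C++) module that parses CLI-style options."""

    def __init__(self) -> None:
        self.args: List[str] = []

    def set_args(self, args: List[str]) -> None:
        self.args = list(args)


class Model:
    """Hypothetical wrapper: documented parameters are explicit,
    everything else is forwarded to the backend via **kwargs."""

    def __init__(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
        self._backend = HypotheticalBackend()
        # Documented parameters first, then any extra (undocumented) options
        params: Dict[str, Any] = {"output": output, "verbose": verbose, **kwargs}
        self._backend.set_args(self._to_args(params))

    @staticmethod
    def _to_args(params: Dict[str, Any]) -> List[str]:
        args: List[str] = []
        for key, value in params.items():
            # Booleans become "true"/"false"; other values are stringified
            if isinstance(value, bool):
                value = str(value).lower()
            args += [f"--{key}", str(value)]
        return args


m = Model("eurlex-model2", tree_structure="eurlex-model/tree", verbose=True)
print(m._backend.args)
```

The point of this design is that the Python layer never has to know about `tree_structure`: an option added only on the C++ side becomes usable from Python as soon as it is passed as a keyword argument.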

So you can use tree_structure out of the box, as in the example below, which trains two PLTs: the first one constructs its tree using hierarchical k-means clustering, and the second one loads the tree created for the first one.

from napkinxc.datasets import load_dataset
from napkinxc.models import PLT
from napkinxc.measures import precision_at_k

X_train, Y_train = load_dataset("eurlex-4k", "train")
X_test, Y_test = load_dataset("eurlex-4k", "test")
plt = PLT("eurlex-model")
plt.fit(X_train, Y_train)

Y_pred = plt.predict(X_test, top_k=5)
print("Precision at k:", precision_at_k(Y_test, Y_pred, k=5))

plt2 = PLT("eurlex-model2", tree_structure="eurlex-model/tree", verbose=True)  # I added the verbose option here as proof: it prints a confirmation that the tree was loaded from the given file.
plt2.fit(X_train, Y_train)
Y_pred = plt2.predict(X_test, top_k=5)
print("Precision at k:", precision_at_k(Y_test, Y_pred, k=5)) # Since the tree was the same, this should give very similar results
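For intuition about how the first model's tree might be built, here is a minimal sketch of hierarchical balanced 2-means clustering over label vectors, written in pure Python. It assumes each label is already represented by a dense vector and uses Euclidean distance. The function names, the `max_leaves` parameter, and the `(parent, child)` edge output are illustrative assumptions; this is not napkinXC's implementation, and its output is not napkinXC's tree file format.

```python
import random
from typing import Dict, List, Sequence, Tuple

Vector = Sequence[float]


def _sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _mean(vectors: List[Vector]) -> List[float]:
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def balanced_2means(labels: List[int], vecs: Dict[int, Vector],
                    iters: int = 10, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Split labels (at least 2) into two halves of (almost) equal size."""
    rng = random.Random(seed)
    c0, c1 = (list(vecs[lab]) for lab in rng.sample(labels, 2))
    left: List[int] = []
    right: List[int] = []
    for _ in range(max(iters, 1)):
        # Sort by how much closer each label is to centroid 0 than to centroid 1,
        # then cut in the middle: this enforces the balance constraint.
        order = sorted(labels, key=lambda lab: _sq_dist(vecs[lab], c0) - _sq_dist(vecs[lab], c1))
        half = len(order) // 2
        left, right = order[:half], order[half:]
        c0 = _mean([vecs[lab] for lab in left])
        c1 = _mean([vecs[lab] for lab in right])
    return left, right


def build_tree(labels: List[int], vecs: Dict[int, Vector], max_leaves: int = 2,
               seed: int = 0) -> Tuple[List[Tuple[int, int]], Dict[int, int]]:
    """Return (parent, child) edges with node 0 as the root, plus a leaf-node -> label map."""
    assert max_leaves >= 1
    edges: List[Tuple[int, int]] = []
    leaf_label: Dict[int, int] = {}
    next_id = [1]

    def new_node(parent: int) -> int:
        child = next_id[0]
        next_id[0] += 1
        edges.append((parent, child))
        return child

    def recurse(node: int, labs: List[int]) -> None:
        if len(labs) <= max_leaves:
            for lab in labs:  # small enough: attach labels directly as leaves
                leaf_label[new_node(node)] = lab
            return
        for part in balanced_2means(labs, vecs, seed=seed):
            recurse(new_node(node), part)

    recurse(0, list(labels))
    return edges, leaf_label


vecs = {0: [0.0, 0.0], 1: [0.1, 0.0], 2: [5.0, 5.0], 3: [5.1, 5.0]}
edges, leaf_label = build_tree(list(vecs), vecs, max_leaves=2)
print(edges, leaf_label)
```

With these toy vectors, the two nearby pairs of labels end up under separate children of the root, which is the kind of grouping a k-means-based label tree aims for.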

When it comes to the tree format, it's pretty strict and limited right now:

xiaohan2012 commented 3 years ago

Thank you, it works!