Xtra-Computing / FedTree

A tree-based federated learning system (MLSys 2023)
https://fedtree.readthedocs.io/en/latest/index.html
Apache License 2.0

Multi-class classification problem #66

Open stamatisvas opened 1 year ago

stamatisvas commented 1 year ago

I am using 3 clients to train a model, with the following parameters on the clients:

data=./mydata1.csv
test_data=./mytest_data.csv
model_path=mymodel.model
n_parties=3
data_format=csv
n_features=11
objective=multi:softmax
mode=horizontal
partition=0
learning_rate=0.1
max_depth=6
n_trees=50
ip_address=192...

Training runs without errors, but at prediction time the model outputs only 0.0 even though the labels are 0, 1, and 2. I also tried objective=multi:softprob, with the same result. I tried both the Python .predict function and the terminal, i.e. /build/bin/FedTree-predict ./predict.conf, and got exactly the same results (only 0).

QinbinLi commented 1 year ago

Hi @stamatisvas ,

It seems that you did not specify the num_class parameter, which should be equal to the number of classes in your dataset.
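As a rough illustration of this suggestion, the sketch below counts the distinct labels and writes num_class next to the objective in a key=value config. The helper names count_classes and render_conf and the toy label list are hypothetical and not part of FedTree; the sketch also assumes the labels passed in cover every class across all parties.

```python
def count_classes(labels):
    """Return the number of distinct class labels (assumes every class appears)."""
    return len({int(float(v)) for v in labels})


def render_conf(params):
    """Render a dict as key=value lines, one parameter per line."""
    return "\n".join(f"{key}={value}" for key, value in params.items()) + "\n"


# Toy labels standing in for the y column of the training CSV (hypothetical data).
labels = [0, 1, 2, 1, 0, 2]

params = {
    "data": "./mydata1.csv",
    "data_format": "csv",
    "objective": "multi:softmax",
    # The parameter named in the reply above: one entry per class.
    "num_class": count_classes(labels),
    "mode": "horizontal",
}

conf_text = render_conf(params)
print(conf_text)
```

With labels 0, 1, and 2 this produces num_class=3. In a horizontal setup a single client's partition may not contain every class, so counting over the union of all labels is the safer choice.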

stamatisvas commented 1 year ago

Yes, this was the problem and it is now working. Thank you! However, now that I have a trained model, I have run into the following problems:

1) I want to use the Python interface to load the trained model and predict on the test data (CSV). What format should the test data follow? During training I used a CSV file with the columns [id, y, x1, x2, x3, x4, x5, x6, x7, x8, x9]. I tried the usual approaches for the test data, such as dropping the label column, but the following call always ends in an error:

predictions = federatedcsvmodel.predict(X_test)

2023-05-26 12:06:21,520 INFO [default] #instances = 10845, #features = 11
2023-05-26 12:06:21,526 INFO Performance checkpoint [init trees] for block [predict] : [1 ms]
2023-05-26 12:06:21,527 INFO Performance checkpoint [copy data] for block [predict] : [2 ms ([0 ms] from checkpoint 'init trees')]
2023-05-26 12:06:21,683 INFO Executed [predict] in [158 ms]
free(): invalid pointer

2) When I run ./build/bin/FedTree-predict ./predict.conf from the terminal, where predict.conf points to the CSV test data, the output is:

2023-05-26 12:09:06,024 INFO dataset.cpp:396 : loading csv dataset from file ## /media/sf_Shared_Ubuntu_Windows/fedtree_data/test_data.csv ##
2023-05-26 12:09:06,142 INFO dataset.cpp:567 : #instances = 10845, #features = 10
2023-05-26 12:09:06,142 INFO dataset.cpp:581 : Load dataset using time: 0.117392 s
2023-05-26 12:09:06,161 INFO fedtree_predict.cpp:75 : multi-class accuracy = 0.984417

which makes sense (I get 99% accuracy with centralized models). However, the prediciton.txt file saved in the folder contains seemingly random predictions, most of which are 0, and this gives 0% accuracy when I check it manually.
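For reference, the kind of manual check described here can be sketched as follows. This is only an illustration under assumptions the thread does not confirm: the saved predictions are taken to be one hard class label per line, in the same row order as the test CSV, and the label is taken to be the second column (y) after a header row. The function names and the in-memory sample data are hypothetical.

```python
import csv
import io


def read_labels(csv_text, label_column=1, has_header=True):
    """Read true labels from CSV text; label_column=1 assumes [id, y, x1, ...]."""
    rows = [row for row in csv.reader(io.StringIO(csv_text)) if row]
    if has_header:
        rows = rows[1:]
    return [int(float(row[label_column])) for row in rows]


def read_predictions(text):
    """Read predictions, assuming one hard class label per non-empty line."""
    return [int(float(line)) for line in text.splitlines() if line.strip()]


def accuracy(y_true, y_pred):
    """Fraction of matching labels; a length mismatch signals a format problem."""
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"{len(y_true)} labels but {len(y_pred)} predictions; "
            "the file layout is probably not one prediction per row"
        )
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Small in-memory stand-ins for the test CSV and the saved predictions.
test_csv = "id,y,x1,x2\n1,0,0.5,1.2\n2,1,0.1,0.3\n3,2,0.9,0.8\n4,1,0.4,0.7\n"
pred_txt = "0.0\n1.0\n2.0\n0.0\n"

y_true = read_labels(test_csv)
y_pred = read_predictions(pred_txt)
print(accuracy(y_true, y_pred))
```

If the number of predictions does not match the number of test rows, the length check fails loudly, which points to a layout mismatch (for example, several values per instance) rather than to a poorly performing model.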