BalticBytes / vertical-federated-learning-kang


Torch model running on adult does not seem to learn #3

Closed: lbhm closed this issue 1 year ago

lbhm commented 1 year ago

Hi @BalticBytes,

following our discussion in the Flower repo, I pulled your code as a baseline to implement a VFL prototype for my own use case.

However, I was a bit confused to see that the configured model does not seem to learn anything when training on the Adult dataset. Have you observed anything similar?

The only local changes that I made to your code are updating the torch version (because I run Python 3.10) and adding a to_numpy() call in two lines.

This is an excerpt of the output I received:

~/Projects/vertical-federated-learning-kang (main*) » python torch_vertical_FL_train.py --dname ADULT --model_type vertical 
...
Start vertical FL......

For the 1-th epoch, train loss: 0.32272234559059143, test auc: 0.9111101820607762
For the 2-th epoch, train loss: 0.3679255545139313, test auc: 0.8912106874896119
For the 3-th epoch, train loss: 0.12276240438222885, test auc: 0.8886073102771038
For the 4-th epoch, train loss: 0.239824578166008, test auc: 0.8806176046638836
For the 5-th epoch, train loss: 0.07423055917024612, test auc: 0.878369684654863
For the 6-th epoch, train loss: 0.05435499921441078, test auc: 0.8821973720347704
For the 7-th epoch, train loss: 0.15234878659248352, test auc: 0.875332326395491
For the 8-th epoch, train loss: 0.15903599560260773, test auc: 0.8771023581842844
For the 9-th epoch, train loss: 0.11753567308187485, test auc: 0.8788756352858917
For the 10-th epoch, train loss: 0.08219484239816666, test auc: 0.8721860274393107
For the 11-th epoch, train loss: 0.05476965382695198, test auc: 0.8822209164610665
For the 12-th epoch, train loss: 0.09200111031532288, test auc: 0.8814176060892759
For the 13-th epoch, train loss: 0.04584350064396858, test auc: 0.875801178646707
For the 14-th epoch, train loss: 0.23029093444347382, test auc: 0.879850692702475
For the 15-th epoch, train loss: 0.03156660869717598, test auc: 0.8709582810583436
For the 16-th epoch, train loss: 0.24415796995162964, test auc: 0.8727847558366633
For the 17-th epoch, train loss: 0.04956258833408356, test auc: 0.8775121584690063
For the 18-th epoch, train loss: 0.0777658000588417, test auc: 0.8742061392155202
For the 19-th epoch, train loss: 0.21288323402404785, test auc: 0.8769827270452661
For the 20-th epoch, train loss: 0.11331519484519958, test auc: 0.8769915721135233
For the 21-th epoch, train loss: 0.032309386879205704, test auc: 0.8721464473496994
For the 22-th epoch, train loss: 0.13203926384449005, test auc: 0.8746273935454674
For the 23-th epoch, train loss: 0.05802006646990776, test auc: 0.8678184727277911
For the 24-th epoch, train loss: 0.0868304967880249, test auc: 0.8756367494641416
For the 25-th epoch, train loss: 0.09309358149766922, test auc: 0.8724383982357715
For the 26-th epoch, train loss: 0.1256425380706787, test auc: 0.8709678260960312
For the 27-th epoch, train loss: 0.06171223521232605, test auc: 0.8738337554785334
For the 28-th epoch, train loss: 0.011140420101583004, test auc: 0.8744009216179323
For the 29-th epoch, train loss: 0.04139363393187523, test auc: 0.8762606131274049
For the 30-th epoch, train loss: 0.116118423640728, test auc: 0.8705610165897846
...

The same happens when I use --model_type centralized: the train loss bounces around while the test AUC stagnates at around 0.87-0.88.

Before diving into the dataset and tweaking hyperparameters, I first wanted to confirm that I can roughly reproduce what you saw in your experiments.
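In both logs the train loss column jumps between roughly 0.01 and 0.37 from one epoch to the next. That pattern is what you would expect if the script prints the loss of a single mini-batch (for example the last one) rather than an epoch average; this is an assumption and has not been checked against torch_vertical_FL_train.py. A minimal sketch of a batch-size-weighted running mean, using a hypothetical RunningMean helper, that would give a smoother per-epoch number:

```python
class RunningMean:
    """Hypothetical helper: averages per-batch losses over an epoch, weighted by batch size."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, batch_loss: float, batch_size: int) -> None:
        # Weight each batch by its size so a small final batch does not dominate.
        self.total += batch_loss * batch_size
        self.count += batch_size

    @property
    def mean(self) -> float:
        # Return NaN before any batch has been recorded.
        return self.total / self.count if self.count else float("nan")


# Usage with made-up batch losses (not taken from the logs).
epoch_loss = RunningMean()
for loss, size in [(0.5, 2), (0.2, 8)]:
    epoch_loss.update(loss, size)
print(round(epoch_loss.mean, 4))  # 0.26
```

Inside a torch training loop one would call update(loss.item(), batch_size) after each optimizer step and print mean once at the end of the epoch.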

BalticBytes commented 1 year ago

Start vertical FL......

For the 1-th epoch,  train loss: 0.273951917886734,    test auc: 0.9117442270975729
For the 2-th epoch,  train loss: 0.1786241978406906,   test auc: 0.900648693487968
For the 3-th epoch,  train loss: 0.13361237943172455,  test auc: 0.8888936614077327
For the 4-th epoch,  train loss: 0.13391605019569397,  test auc: 0.8804256857727777
For the 5-th epoch,  train loss: 0.09747820347547531,  test auc: 0.883832882425753
For the 6-th epoch,  train loss: 0.08075136691331863,  test auc: 0.8767451192404287
For the 7-th epoch,  train loss: 0.032229404896497726, test auc: 0.8739121520547413
For the 8-th epoch,  train loss: 0.11218876391649246,  test auc: 0.883627664115469
For the 9-th epoch,  train loss: 0.055012766271829605, test auc: 0.8771337931750689
For the 10-th epoch, train loss: 0.09364917874336243,  test auc: 0.8804944100441285
For the 11-th epoch, train loss: 0.09003350883722305,  test auc: 0.8719841817090097
For the 12-th epoch, train loss: 0.1260099858045578,   test auc: 0.8778461711544889
For the 13-th epoch, train loss: 0.1729588657617569,   test auc: 0.8783605850522673
For the 14-th epoch, train loss: 0.09362012147903442,  test auc: 0.8724239534120709
For the 15-th epoch, train loss: 0.030503258109092712, test auc: 0.874701208503585
For the 16-th epoch, train loss: 0.10761567950248718,  test auc: 0.8750783011258308
For the 17-th epoch, train loss: 0.023991085588932037, test auc: 0.8798602377401628
For the 18-th epoch, train loss: 0.04462840035557747,  test auc: 0.8764336964774738
For the 19-th epoch, train loss: 0.10871774703264236,  test auc: 0.8694285296849399
For the 20-th epoch, train loss: 0.049978107213974,    test auc: 0.8761117105394779
For the 21-th epoch, train loss: 0.05989370122551918,  test auc: 0.8678467260393464
For the 22-th epoch, train loss: 0.04775784909725189,  test auc: 0.8730275179618519
For the 23-th epoch, train loss: 0.08232229948043823,  test auc: 0.8760430499017116
For the 24-th epoch, train loss: 0.030299151316285133, test auc: 0.8719921995406673
For the 25-th epoch, train loss: 0.031436219811439514, test auc: 0.8775824099463876
For the 26-th epoch, train loss: 0.123594731092453,    test auc: 0.871307311269788
For the 27-th epoch, train loss: 0.05187857151031494,  test auc: 0.8689973485157976
For the 28-th epoch, train loss: 0.07397286593914032,  test auc: 0.8725011409501716
For the 29-th epoch, train loss: 0.07121347635984421,  test auc: 0.8681039329882356
For the 30-th epoch, train loss: 0.13118626177310944,  test auc: 0.867342111713594
lbhm commented 1 year ago

I see. Thank you for confirming my numbers.

The figure intuitively looks like reasonable training progress to me, but the logs we are seeing feel off: the AUC is highest after the first epoch, and training shows no visible improvement after that. I am not really familiar with the dataset, so I can only guess, but at first sight I would say that the DNN layers and/or hyperparameters are not configured well by the original authors.

BalticBytes commented 1 year ago

I recall something about the dimensionality of the training data. In some configuration (after one-hot encoding, I believe), the training data had over twenty thousand dimensions.

Maybe you want to check the dimensions of the data before training.
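One way to run that check is to compare how many input features one-hot encoding produces with and without the numerical columns. Each one-hot encoded column contributes one feature per distinct value, so a continuous column (Adult has several, such as age and fnlwgt) can add thousands of features on its own. A minimal sketch on a synthetic frame; encoded_widths is an illustrative helper, not part of the repository:

```python
import numpy as np
import pandas as pd


def encoded_widths(frame: pd.DataFrame) -> tuple[int, int]:
    """Return (features if every column is one-hot encoded,
    features if only non-numerical columns are one-hot encoded)."""
    categorical = frame.select_dtypes(exclude="number")
    numerical = frame.select_dtypes(include="number")
    # One-hot encoding turns each distinct value of a column into its own feature.
    all_one_hot = int(frame.nunique().sum())
    # Scaled numerical columns keep exactly one feature each.
    mixed = int(categorical.nunique().sum()) + numerical.shape[1]
    return all_one_hot, mixed


# Synthetic stand-in for a few Adult-like columns (random values, not real data).
rng = np.random.default_rng(0)
n = 1000
toy = pd.DataFrame({
    "fnlwgt": rng.integers(10_000, 1_000_000, size=n),
    "age": rng.integers(17, 91, size=n),
    "sex": rng.choice(["Male", "Female"], size=n),
})
all_one_hot, mixed = encoded_widths(toy)
print(all_one_hot, mixed)
```

On the real data, a first number in the tens of thousands next to a much smaller second number would point to numerical columns being one-hot encoded.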

lbhm commented 1 year ago

I noticed two minor bugs in your code while looking into this:

The cause of the high dimensionality is that all columns (even the numerical ones) get one-hot encoded. I fixed this by adding a StandardScaler from sklearn and modifying the code like this:

data_splits[i] = X[attribute_groups[i]]
encoded = encoder.fit_transform(data_splits[i].select_dtypes(exclude="number"))
scaled = scaler.fit_transform(data_splits[i].select_dtypes(include="number"))
encoded_data_splits[i] = np.concatenate([encoded, scaled], axis=1)
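For reference, a self-contained sketch of the same idea, wrapped in a hypothetical encode_party_features function that one party would call on its own column subset. Compared with the snippet above it adds the imports and guards for a party that holds only numerical or only categorical columns (StandardScaler, for example, refuses input with zero columns). It also densifies the one-hot output, because OneHotEncoder returns a sparse matrix by default, which np.concatenate cannot combine directly with a dense array; whether the repository's encoder is already configured for dense output is not visible in this thread.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, StandardScaler


def encode_party_features(frame: pd.DataFrame) -> np.ndarray:
    """One-hot encode categorical columns and standardize numerical ones.

    Hypothetical helper: each party calls it on its own column subset, so the
    encoder and scaler are fitted locally and never see another party's features.
    """
    parts = []
    categorical = frame.select_dtypes(exclude="number")
    numerical = frame.select_dtypes(include="number")
    if categorical.shape[1] > 0:
        encoder = OneHotEncoder(handle_unknown="ignore")
        # fit_transform returns a sparse matrix by default; densify before concatenating.
        parts.append(encoder.fit_transform(categorical).toarray())
    if numerical.shape[1] > 0:
        scaler = StandardScaler()
        parts.append(scaler.fit_transform(numerical))
    # Same column order as the snippet above: one-hot block first, then scaled block.
    return np.concatenate(parts, axis=1)


# Hypothetical toy frame standing in for one party's share of the Adult columns.
toy = pd.DataFrame({
    "age": [25, 38, 52, 61],
    "hours-per-week": [40, 50, 45, 20],
    "workclass": ["Private", "Self-emp", "Private", "State-gov"],
})
features = encode_party_features(toy)
print(features.shape)  # (4, 5): 3 one-hot columns + 2 scaled numerical columns
```

In a full pipeline the encoder and scaler would normally be fitted on the training rows only and then applied to the test rows; this sketch, like the snippet above, fits on whatever frame it receives.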

The test AUC increases to 0.90-0.91 after these changes, but the model still does not seem to learn anything: it reaches this AUC after the first epoch and then stagnates. Therefore, I assume that either (a) the dataset is so simple that the model gets as good as it can after the first epoch, or (b) the DNN layers and hyperparameters are not well configured. In any case, I would have expected the authors of the original publication to have investigated this.
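If the best AUC really is reached after the first epoch, one common mitigation (an addition here, not something the thread says the repository does) is to track the best evaluation AUC, stop once it has not improved for a few epochs, and keep the weights from the best epoch. A minimal, framework-free sketch of that stopping rule, fed with the first five test AUCs from the vertical FL log in the opening post; the early_stop name and the patience value of 3 are assumptions. In practice the rule should watch a validation split rather than the test set, otherwise the reported test AUC becomes optimistic.

```python
from typing import Iterable, Tuple


def early_stop(aucs: Iterable[float], patience: int = 3) -> Tuple[int, float, int]:
    """Scan per-epoch AUCs and apply a patience-based stopping rule.

    Returns (best_epoch, best_auc, last_epoch_run), with 1-indexed epochs
    to match the "For the N-th epoch" log lines.
    """
    best_epoch, best_auc = 0, float("-inf")
    epochs_without_gain = 0
    epoch = 0
    for epoch, auc in enumerate(aucs, start=1):
        if auc > best_auc:
            # New best: a real training loop would checkpoint the model here.
            best_epoch, best_auc, epochs_without_gain = epoch, auc, 0
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                break  # stop training; restore the checkpoint from best_epoch
    return best_epoch, best_auc, epoch


# First five test AUCs from the vertical FL log in the opening post.
aucs = [0.9111101820607762, 0.8912106874896119, 0.8886073102771038,
        0.8806176046638836, 0.878369684654863]
best_epoch, best_auc, last_epoch = early_stop(aucs, patience=3)
print(best_epoch, last_epoch)  # 1 4
```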

lbhm commented 1 year ago

FYI: I created a repo of my own to collect my adaptation of this use case and the Flower VFL implementation that I am working on. I referenced your work and the original authors, so I hope that this is OK with you (even though your work technically does not have a license 😉). Please let me know if it is not, or if you would prefer any changes to the reference.

BalticBytes commented 1 year ago

The authors did not provide the data. The data in the repository actually comes from a third party.

So I do think there are some subtle differences, but the results stay close enough to each other.