EmmaRocheteau / TPC-LoS-prediction

This repository contains the code used for Temporal Pointwise Convolutional Networks for Length of Stay Prediction in the Intensive Care Unit (https://dl.acm.org/doi/10.1145/3450439.3451860).

Some questions about performance #7

Closed weicheng113 closed 2 years ago

weicheng113 commented 2 years ago

Hi Emma,

For the figures in the performance tables (take Table 2 for example), are the scores calculated from the test dataset (I assume the test dataset) or the validation dataset?

Another question about the transformer model in transformer_model.py.

class TransformerEncoder(nn.Module):
  ...
  def forward(self, X, T):
    ...
    # question about this line.
    X = self.transformer_encoder(src=X.permute(2, 0, 1), mask=self._causal_mask(size=T))  # T * B * d_model
    ...

Is _causal_mask telling the transformer to mask padded data?

Thanks, Cheng

weicheng113 commented 2 years ago

By the way, I used LMDB to replace the shuffle step (to remove the shuffle time), in addition to a multi-process data loader, and reduced TPC training time to about 4 hours 12 minutes (total time 4:12:33) on an AWS g4dn.2xlarge with an NVIDIA T4 GPU (16 GB).

csv2lmdb.py converts the CSV to LMDB so that it supports random access.

eicu_dataset2.py is the dataset for the multi-process data loader. A rough sketch of the idea is below.
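
This is only a minimal sketch of the conversion, not the exact script; the paths and the 'patient' grouping column are placeholders:

import lmdb
import pickle
import pandas as pd

# Store each patient's time series under an integer key so a multi-process
# DataLoader can fetch random samples directly, with no up-front shuffle pass.
df = pd.read_csv('eICU_data/train/timeseries.csv')
env = lmdb.open('eICU_data/train.lmdb', map_size=10 * 1024 ** 3)  # 10 GB ceiling

with env.begin(write=True) as txn:
    for i, (_, patient_rows) in enumerate(df.groupby('patient')):
        txn.put(str(i).encode(), pickle.dumps(patient_rows))

# In the Dataset's __getitem__, a read transaction does the reverse:
#   with env.begin() as txn:
#       patient_rows = pickle.loads(txn.get(str(index).encode()))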

EmmaRocheteau commented 2 years ago

Hi! So I didn't report any values from the validation data in the paper; it's all from the test dataset. I did check the validation performances to choose hyperparameters though.

The causal mask is to mask values that come from future time points when calculating representations for earlier time points. So if I want to calculate a representation for time t using a transformer, and then use that representation to make a mortality and length of stay prediction, I don't want any information from time t + 1, t + 2 etc. getting into that representation at time t, otherwise it's cheating because it can see into the future!
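
For reference, here is a minimal sketch of how such a mask is typically built in PyTorch (the _causal_mask in the repo may differ in details):

import torch

# Standard causal (subsequent) mask: position t may only attend to positions
# 0..t. Blocked entries are -inf, which become zero weight after the softmax.
def causal_mask(size: int) -> torch.Tensor:
    mask = torch.full((size, size), float('-inf'))
    return torch.triu(mask, diagonal=1)  # zeros on and below the diagonal

print(causal_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])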

Ah yes my shuffle implementation was really slow sorry!

weicheng113 commented 2 years ago

Thanks a lot for detailed explanation, Emma.

Ok, got you. That is why the mask is a triangular matrix. Thanks.

weicheng113 commented 2 years ago

Hi Emma,

I did some experiments and there are some differences in the ChannelwiseLSTM and Transformer results. If you have some time, please have a look.

[screenshot: experiment results for each model]

The corresponding test_mad errors from the paper are below:

ChannelwiseLSTM: 2.37
PointwiseOnly: 2.68
TempOnly: 1.91
TPC: 1.78
TPCNoSkip: 1.93
Transformer: 2.36

The experiments for ChannelwiseLSTM (30 epochs) and Transformer (15 epochs) came out much worse than the results in the paper. I was using the best parameters from the final_experiment_scripts. The only difference was the batch size, which I increased to better utilize the GPU. I can see the validation loss for the Transformer was still going down, so maybe I can train it a little longer. Do you have any suggestions on this? The experiment outputs are here.

TempOnly came out better than in the paper.

Thanks, Cheng

EmmaRocheteau commented 2 years ago

Hi! I’m out today so I can’t look in a lot of detail. Would it be possible just to test the same batch size on the GPU even though it will be slower to train?

weicheng113 commented 2 years ago

Thanks Emma, I will try that out on the Transformer to see if it is better. The LSTM is too slow to train :).

EmmaRocheteau commented 2 years ago

Note that the best batch size for each model is different as well! :)

weicheng113 commented 2 years ago

Thanks for the reminder, Emma. I will check carefully.

weicheng113 commented 2 years ago

Hi Emma,

When you have some time, could you please help me with the following questions and understanding?

[figure: histograms of Total LoS and Remaining LoS]

a. Is the unit on the y-axis count? If so, from b and c below, shouldn't Remaining LoS have a much higher count than Total LoS?
b. Total LoS: is Total LoS the actualiculos from apachepatientresult, i.e. one value for each patient, with the plot being the histogram of the actualiculos values?
c. Remaining LoS: are these the values from preprocessing, i.e. one value per hour and multiple values per patient, with the plot being the histogram of those values?

The following is my surface understanding of the reasoning for the use of MSLE:

d. For shorter lengths of stay, a more accurate prediction is required, and for longer lengths of stay, a larger prediction error is allowed. The allowed error is roughly proportional to the actual length of stay: less tolerance for shorter stays and more tolerance for longer stays.
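
To make d. concrete, here is a small sketch of MSLE as I understand it (the repo may implement it slightly differently, e.g. with clamping):

import torch

def msle(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # mean squared error in log space; log1p keeps log(0) out of the picture
    return torch.mean((torch.log1p(y_pred) - torch.log1p(y_true)) ** 2)

# the log squashes large values, so being 10 days off on a long stay costs
# far less than being 10 days off on a short stay
print(msle(torch.tensor([60.0]), torch.tensor([50.0])))  # ~0.032
print(msle(torch.tensor([11.0]), torch.tensor([1.0])))   # ~3.21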

By the way, for the transformer model I set batch_size=32, the same as the original best_hyperparameters.py, and the performance had not improved by the 15th epoch (similar to the one posted above: Validation Mean absolute deviation (MAD) = 2.617043680500727). I set it to 30 epochs in total, so it is still going, but it does not look promising.

Thanks, Cheng

weicheng113 commented 2 years ago

Hi Emma,

Below is the val_mad over 30 epochs. The lowest point was around 2.62.

[plot: val_mad over 30 epochs]

The configuration used is as follows:

2022-05-03 04:01:33,516 - INFO - Config:
2022-05-03 04:01:33,516 - INFO - {
    "L2_regularisation": 0,
    "alpha": 100,
    "base_dir": "models/experiments/final/eICU/LoS/Transformer",
    "batch_norm": "mybatchnorm",
    "batch_size": 32,
    "batch_size_test": 32,
    "batchnorm": "mybatchnorm",
    "d_model": 16,
    "dataset": "eICU",
    "diagnosis_size": 64,
    "disable_cuda": false,
    "exp_name": "Transformer",
    "feedforward_size": 256,
    "intermediate_reporting": false,
    "labs_only": false,
    "last_linear_size": 17,
    "learning_rate": 0.00017,
    "loss": "msle",
    "main_dropout_rate": 0.45,
    "mode": "train",
    "n_epochs": 30,
    "n_heads": 2,
    "n_layers": 6,
    "name": "Transformer",
    "no_diag": false,
    "no_exp": false,
    "no_labs": false,
    "no_mask": false,
    "percentage_data": 100.0,
    "positional_encoding": false,
    "save_results_csv": false,
    "seed": 284398694,
    "shuffle_train": true,
    "sum_losses": true,
    "task": "LoS",
    "trans_dropout_rate": 0
}

Detailed experiment information is here

Thanks, Cheng

EmmaRocheteau commented 2 years ago

Hi Cheng,

I'll deal with the questions first.

a. Is the unit on the y-axis count? If so, from b and c below, shouldn't Remaining LoS have a much higher count than Total LoS?

Yes, it's count, and yes, Remaining LoS has many more counts than Total LoS. I removed the actual values because the plot was focusing on the distribution, but your understanding is correct!

b. Total LoS: is Total LoS the actualiculos from apachepatientresult, i.e. one value for each patient, with the plot being the histogram of the actualiculos values?

Yes

c. Remaining LoS: are these the values from preprocessing, i.e. one value per hour and multiple values per patient, with the plot being the histogram of those values?

Yes that's correct

The following is my surface understanding of the reasoning for the use of MSLE. d. For shorter lengths of stay, a more accurate prediction is required, and for longer lengths of stay, a larger prediction error is allowed. The allowed error is roughly proportional to the actual length of stay: less tolerance for shorter stays and more tolerance for longer stays.

Yes that's correct.

By the way, for the transformer model I set batch_size=32, the same as the original best_hyperparameters.py, and the performance had not improved by the 15th epoch (similar to the one posted above: Validation Mean absolute deviation (MAD) = 2.617043680500727). I set it to 30 epochs in total, so it is still going, but it does not look promising.

Okay I will have a look now

weicheng113 commented 2 years ago

Thanks a lot for the confirmation, Emma.

I just tried multi-task TPC. With mortality as side task, the result looks good.

train_mad: 1.756100856679557
val_mad: 1.2926345455305368  # I have not run the test yet, but it should be similar.

EmmaRocheteau commented 2 years ago

Hi Cheng,

Sorry for the delay. I didn't get time to look into it today. Something came up. I will get back to you when I can

Emma :)

weicheng113 commented 2 years ago

Thanks Emma. Not urgent, you can check when you have time. I think it will take some time to check this. I am working on other things.

weicheng113 commented 2 years ago

Hi Emma,

I am trying to understand the resample_and_mask logic in timeseries.py. When you have time, please help me confirm my understanding below. (It is difficult to debug this part of the code to confirm the logic, as it requires going through the data reconfiguration process, so I am trying to understand it by reading the code.)

The time series for each patient is resampled hourly, and the mean of the values falling in a particular hour is the final value for that hour. Missing values in any hour are forward filled.

Assume we have a field f and its corresponding decay indicator field m. At hour 1, f has value v1, so m is 1.0 to indicate there is a measured value for f at hour 1. At hour 2, f has no value and is forward filled, so f is still v1 and m = 0.75. At hour 3, f again has no value, so f = v1 and m = 0.75 * 0.75. At hour 4, f has a value, so f = v4 and m = 1.0. The decay indicator field tells us how much we can trust the corresponding value, i.e. how stale or fresh it is. Is my understanding of the decay indicator field correct? (A small sketch of what I mean is below.)
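
A minimal sketch of the behaviour I have in mind (the column names and the 0.75 decay factor are my assumptions from reading the code):

import numpy as np
import pandas as pd

ts = pd.DataFrame(
    {'f': [5.0, np.nan, np.nan, 8.0]},
    index=pd.date_range('2022-01-01', periods=4, freq='H'),
)

observed = ts['f'].notna()   # True where a real measurement exists
ts['f'] = ts['f'].ffill()    # forward fill the gaps

# decay indicator: 1.0 at a fresh measurement, multiplied by 0.75 for every
# hour the value has been carried forward
m = np.ones(len(ts))
for i in range(1, len(ts)):
    m[i] = 1.0 if observed.iloc[i] else m[i - 1] * 0.75
ts['f_mask'] = m

print(ts)
#                        f  f_mask
# 2022-01-01 00:00:00  5.0  1.0000
# 2022-01-01 01:00:00  5.0  0.7500
# 2022-01-01 02:00:00  5.0  0.5625
# 2022-01-01 03:00:00  8.0  1.0000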

[screenshot: resampled time series with decay indicator (mask) columns]

a. For st3_mask, as there are values for every hour, the values of its decay indicator field are all 1.0.
b. For the first hour, why do we get a decay indicator value of 0.0 for systemicdiastolic_mask? Should it be 1.0?
c. For the first hour, why do we get a decay indicator value of 0.187500 for total bilirubin_mask? Should it be 1.0, as it has a value at the first hour?
d. For total bilirubin_mask, as we have a value at hour 4, its decay indicator value is set to 1.0.

By the way, I am currently running an experiment that also feeds the decay indicator fields into the skip connection, to see if it is of any help (roughly as sketched below).
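
Roughly what I mean (the shapes and names here are mine, not the repo's):

import torch

B, F, T = 4, 10, 24                  # batch, features, time steps
X = torch.randn(B, F, T)             # raw feature values
M = torch.rand(B, F, T)              # decay indicator (mask) fields
layer_out = torch.randn(B, 32, T)    # output of a TPC layer

skip = torch.cat([layer_out, X], dim=1)               # current: features only
skip_with_mask = torch.cat([layer_out, X, M], dim=1)  # proposed: features + masks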

Thanks, Cheng

weicheng113 commented 2 years ago

Hi Emma,

Just to let you know, I got some promising results with the additional mask fields in the skip connection (the first result below). I am not sure if it is because of the increase in network weights or because the mask fields do help in this case.

test_mad: 1.2677834890621544  # TPC skip connection with feature value and mask field; about 1.15M parameters.

test_mad: 1.65898419174147    # TPC skip connection with feature value; about 1.09M parameters.
test_mad: 1.2416521974053145  # TPC with multitask 

The result is here and the modified model file is here.

I will try it with multitask tomorrow to see if it brings an improvement there as well.

EmmaRocheteau commented 2 years ago

Just to let you know that I will get back to you soon! Just juggling some things today

weicheng113 commented 2 years ago

Hi Emma,

Not urgent. You can have a look when you have time. Thanks a lot.

weicheng113 commented 2 years ago

Hi Emma,

Just to let you know, TPC multitask with the additional mask fields in the skip connection did not improve performance (it is similar to plain TPC multitask). I saw something strange in the confusion matrices for most of the epochs, something like the one below.

# the model never predicts the positive (died) class, so all 1347 died cases are misclassified
[[12964     0]
 [ 1347     0]]

The experiment output is here. I am currently guessing it may have to do with alpha; I will need to do some debugging.

This is just information to share with you. Thanks for all the help.