EmmaRocheteau / TPC-LoS-prediction

This repository contains the code used for Temporal Pointwise Convolutional Networks for Length of Stay Prediction in the Intensive Care Unit (https://dl.acm.org/doi/10.1145/3450439.3451860).
MIT License

some conceptual questions about temp_pointwise #6

Closed: weicheng113 closed this issue 2 years ago

weicheng113 commented 2 years ago

Hi Emma,

I have some conceptual questions about the temp_pointwise implementation. I have marked the three steps my questions refer to in the source code below. The comments are my own understanding, and the four code lines are extracted from your source code.

def temp_pointwise(...):
  ...
  # temp_skip: (batch_size, ts_feature_value_dim, ts_feature_conv_dim + 1, n_measure_of_patient)
  # temp_skip combines the temporal convolution and the skip connection: for each feature, the
  # ts_feature_conv_dim values (12 values in one layer) are concatenated with that feature's value from the skip connection.
  # step 1
  temp_skip = cat((point_skip.unsqueeze(2),  # B * (F + Zt) * 1 * T
                         X_temp.view(B, point_skip.shape[1], temp_kernels, T)),  # B * (F + Zt) * temp_kernels * T
                        dim=2)  # B * (F + Zt) * (1 + temp_kernels) * T

  # point_output: (batch_size * n_measure_of_patient, point_size)
  #   -> view(batch_size, n_measure_of_patient, point_size, 1)
  #   -> permute -> (batch_size, point_size, 1, n_measure_of_patient)
  #   -> repeat  -> X_point_rep: (batch_size, point_size, ts_feature_conv_dim + 1, n_measure_of_patient)
  # X_point_rep contains a low-dimensional representation of each measure
  # step 2
  X_point_rep = point_output.view(B, T, point_size, 1).permute(0, 2, 3, 1).repeat(1, 1, (1 + temp_kernels), 1)  # B * point_size * (1 + temp_kernels) * T

  # X_combined(batch_size, ts_feature_value_dim + point_size, ts_feature_conv_dim+1, n_measure_of_patient)
  # temp_skip and X_point_rep are concatenated along ts_feature_value_dim axis.
  # step 3
  X_combined = self.relu(cat((temp_skip, X_point_rep), dim=1))  # B * (F + Zt + point_size) * (1 + temp_kernels) * T
  next_X = X_combined.contiguous().view(B, (point_skip.shape[1] + point_size) * (1 + temp_kernels), T)  # B * ((F + Zt + point_size) * (1 + temp_kernels)) * T
  ...
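To check the shape bookkeeping in these three steps, here is a minimal NumPy mirror of the four lines. It is a sketch under assumptions, not repository code: the sizes are made up, and the PyTorch calls are swapped for NumPy equivalents (`reshape` for `view`, `transpose` for `permute`, `np.tile` for `repeat`, `np.maximum(x, 0)` for ReLU). Note that the feature axis of X_combined ends up with F + Zt + point_size entries.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
B, T = 2, 5            # batch size, number of time steps
F_plus_Zt = 4          # F + Zt: channels carried by point_skip
temp_kernels = 3       # temporal conv output channels per feature
point_size = 2         # pointwise output size

rng = np.random.default_rng(0)
point_skip = rng.normal(size=(B, F_plus_Zt, T))
X_temp = rng.normal(size=(B, F_plus_Zt * temp_kernels, T))
point_output = rng.normal(size=(B * T, point_size))

# Step 1: put each feature's skip value in front of its temp_kernels conv outputs.
temp_skip = np.concatenate(
    (point_skip[:, :, None, :],                        # B * (F + Zt) * 1 * T
     X_temp.reshape(B, F_plus_Zt, temp_kernels, T)),   # B * (F + Zt) * temp_kernels * T
    axis=2)                                            # B * (F + Zt) * (1 + temp_kernels) * T

# Step 2: one value per pointwise feature, tiled across the conv dimension.
X_point_rep = np.tile(
    point_output.reshape(B, T, point_size, 1).transpose(0, 2, 3, 1),  # B * point_size * 1 * T
    (1, 1, 1 + temp_kernels, 1))                       # B * point_size * (1 + temp_kernels) * T

# Step 3: concatenate along the feature axis, apply ReLU, flatten feature x conv.
X_combined = np.maximum(np.concatenate((temp_skip, X_point_rep), axis=1), 0)
next_X = X_combined.reshape(B, (F_plus_Zt + point_size) * (1 + temp_kernels), T)

print(temp_skip.shape, X_point_rep.shape, X_combined.shape, next_X.shape)
# (2, 4, 4, 5) (2, 2, 4, 5) (2, 6, 4, 5) (2, 24, 5)
```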

At step 3 (X_combined), my understanding is that temp_skip and X_point_rep are concatenated along ts_feature_value_dim because X_point_rep contains a representation at the ts_feature_value_dim level. If so, why not do the following instead:

X_combined = self.relu(cat(
    (temp_skip.view(B, point_skip.shape[1] * (temp_kernels + 1), T),  # B * ((F + Zt) * (1 + temp_kernels)) * T
     point_output.view(B, T, point_size).permute(0, 2, 1)),           # B * point_size * T
    dim=1))

That is, flatten temp_skip so that it can be concatenated with point_output at the ts_feature_value_dim level.
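The two layouts can be compared by channel count. A minimal NumPy sketch with hypothetical sizes, using NumPy stand-ins for the PyTorch calls:

```python
import numpy as np

B, T, F_plus_Zt, temp_kernels, point_size = 2, 5, 4, 3, 2   # hypothetical sizes
rng = np.random.default_rng(1)
temp_skip = rng.normal(size=(B, F_plus_Zt, 1 + temp_kernels, T))
point_output = rng.normal(size=(B * T, point_size))

# Alternative from the question: flatten temp_skip first, then append the
# pointwise outputs once each (no repetition).
alt = np.maximum(np.concatenate(
    (temp_skip.reshape(B, F_plus_Zt * (1 + temp_kernels), T),
     point_output.reshape(B, T, point_size).transpose(0, 2, 1)),  # B * point_size * T
    axis=1), 0)

# Channel counts: alternative vs. the repeat-then-concatenate version.
alt_channels = F_plus_Zt * (1 + temp_kernels) + point_size
orig_channels = (F_plus_Zt + point_size) * (1 + temp_kernels)
print(alt.shape[1], alt_channels, orig_channels)   # 18 18 24
```

With these sizes the flattened alternative has 18 channels, which is not a multiple of 1 + temp_kernels = 4, so its channels can no longer be split into equal per-feature groups; the original layout has 24 = 6 × 4.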

I have difficulty understanding why each point_size value is repeated (1 + temp_kernels) times at step 2 (X_point_rep). The only reason I can think of is to match the dimensions of temp_skip. But with the repetition, won't next_X contain (1 + temp_kernels) copies of the same value along dim=1, which add no information for the network?
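The redundancy described in the question is easy to confirm: after step 2, each pointwise feature holds 1 + temp_kernels identical copies along the conv dimension. A small NumPy check with hypothetical sizes:

```python
import numpy as np

B, T, point_size, temp_kernels = 2, 5, 2, 3   # hypothetical sizes
point_output = np.arange(B * T * point_size, dtype=float).reshape(B * T, point_size)

X_point_rep = np.tile(
    point_output.reshape(B, T, point_size, 1).transpose(0, 2, 3, 1),
    (1, 1, 1 + temp_kernels, 1))                  # B * point_size * (1 + temp_kernels) * T

# Every slice along the conv dimension is an identical copy of the same values.
copies_identical = all(
    np.array_equal(X_point_rep[:, :, c, :], X_point_rep[:, :, 0, :])
    for c in range(1 + temp_kernels))
print(copies_identical)   # True
```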

Asking about source code in text is a bit difficult, so I am not sure I have stated my question clearly.

Thanks in advance for your time and help, Cheng

EmmaRocheteau commented 2 years ago

Hi Cheng,

Thanks for your question! After reading it a few times I see what you mean. I'm going to refer to the channel dimension as conv_dim since it seems to fit your naming system.

So we want to preserve the structure of ts_feature_value_dim (i.e. not just append the conv_dims to it) so that each feature along this axis refers to one variable. This could be heart rate, if we are looking at the original set in F, or it could be an output of a previous pointwise layer, i.e. a new variable informed by a variety of the original variables. For example, one such variable could be weighted towards indicating "lung health", while another could be primarily concerned with kidney function. These features would be represented in the Zt component of ts_feature_value_dim.
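Because X_combined is flattened feature-major into next_X, each variable's channels stay contiguous: channel f * (1 + temp_kernels) + c of next_X is conv channel c of feature f. A NumPy check of that index mapping, with hypothetical sizes (conv_dim here stands for 1 + temp_kernels):

```python
import numpy as np

B, n_features, conv_dim, T = 2, 6, 4, 5    # hypothetical; n_features = F + Zt + point_size
X_combined = np.random.default_rng(2).normal(size=(B, n_features, conv_dim, T))
next_X = X_combined.reshape(B, n_features * conv_dim, T)

def channel_of(feature, conv_channel):
    """Index in next_X of a given (feature, conv channel) pair."""
    return feature * conv_dim + conv_channel

ok = all(np.array_equal(next_X[:, channel_of(f, c), :], X_combined[:, f, c, :])
         for f in range(n_features) for c in range(conv_dim))
print(ok)   # True
```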

If we were to just append each conv_channel to ts_feature_value_dim, we would effectively be treating each conv output channel as a new feature, so the conv part of the model wouldn't be aware that some of the outputs of the previous conv_channels are related to one another in a structured way. The problem is that as we stack the convolutions on top of one another, we would lose the more complex temporal processing that comes from having multiple conv channels extract different temporal signals from the same variable using the same kernel dimensions.
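The benefit of keeping that structure shows up in a grouped convolution where each feature's channels form one group: a feature's new conv channels can combine all of its own previous conv channels, but nothing from other features. PyTorch's `nn.Conv1d` supports this through its `groups` argument; the NumPy function below is a hypothetical stand-in (its name, weight layout, and causal left-padding are assumptions for illustration, not the repository's code).

```python
import numpy as np

def grouped_causal_conv(x, w):
    """Grouped 1D causal convolution (NumPy sketch, not the repository code).

    x: (B, G * C_in, T), with G groups (one per feature) of C_in channels each
    w: (G, C_out, C_in, K), a separate set of kernels per group
    returns: (B, G * C_out, T); the output at time t uses inputs t-K+1 .. t
    """
    B, _, T = x.shape
    G, C_out, C_in, K = w.shape
    xg = x.reshape(B, G, C_in, T)
    xpad = np.concatenate([np.zeros((B, G, C_in, K - 1)), xg], axis=3)  # left-pad for causality
    out = np.zeros((B, G, C_out, T))
    for t in range(T):
        window = xpad[:, :, :, t:t + K]                      # B * G * C_in * K
        out[:, :, :, t] = np.einsum('bgik,goik->bgo', window, w)
    return out.reshape(B, G * C_out, T)

rng = np.random.default_rng(3)
B, G, C_in, C_out, K, T = 1, 3, 4, 3, 2, 6   # hypothetical sizes
x = rng.normal(size=(B, G * C_in, T))
w = rng.normal(size=(G, C_out, C_in, K))
y = grouped_causal_conv(x, w)

# Perturb only feature 0's channels: only feature 0's outputs should change.
x2 = x.copy()
x2[:, :C_in, :] += 1.0
y2 = grouped_causal_conv(x2, w)
changed = [not np.allclose(y[:, g * C_out:(g + 1) * C_out], y2[:, g * C_out:(g + 1) * C_out])
           for g in range(G)]
print(changed)   # [True, False, False]
```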

I don't know if that makes sense? I'm happy to give some further explanation of my thoughts if it doesn't. The difficulty with that approach (as you correctly point out!) is that we only have one value for the pointwise outputs when they are first formed. So we repeat it to fill the conv dimension so that the sizing works for the next conv layer. The model will need to work out that there is no useful signal in the conv dimension for those features and just focus on creating useful signals from the temporal dimension. In all future layers, however, the model can combine information from a variety of previous conv channels to form the next layer of processing on that feature.

As a side note, there may have been a way to handle the newly added pointwise outputs specially, so that they were treated as having a channel dimension of 1 while the others had a dimension of conv_dim. I didn't go down that route, but perhaps it would be a more "ideal" way of handling it, so that the model doesn't need to do any extra work. There may have been a better reason at the time than "it's more faff and it doesn't work neatly with the library I'm using", but if there was, I've forgotten it. In any case, I think the model should easily be able to handle a few repeated values in the conv channels.
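The point that repeated values cost little can be made precise: a convolution over r identical copies of one channel computes exactly what a single-channel convolution with the summed kernel computes. So the repeated channels add redundant parameters but lose no expressiveness relative to the channel-dimension-of-1 treatment. A NumPy sketch; `causal_conv` is a hypothetical helper and the sizes are made up:

```python
import numpy as np

def causal_conv(x, w):
    """x: (C_in, T); w: (C_out, C_in, K). Returns (C_out, T), left-padded so it is causal."""
    C_in, T = x.shape
    C_out, _, K = w.shape
    xpad = np.concatenate([np.zeros((C_in, K - 1)), x], axis=1)
    return np.stack([np.einsum('ik,oik->o', xpad[:, t:t + K], w) for t in range(T)], axis=1)

rng = np.random.default_rng(4)
T, K, C_out, repeats = 7, 3, 4, 4          # repeats plays the role of 1 + temp_kernels
z = rng.normal(size=(1, T))                 # one pointwise feature, a single channel
z_rep = np.repeat(z, repeats, axis=0)       # repeated to fill the conv dimension
w = rng.normal(size=(C_out, repeats, K))

y_rep = causal_conv(z_rep, w)                              # conv over repeated channels
y_single = causal_conv(z, w.sum(axis=1, keepdims=True))    # one channel, summed kernel
print(np.allclose(y_rep, y_single))   # True
```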

Let me know if I have addressed the whole of your question. Looking at Figure 3 in my paper might also help, as it shows the flow of the dimensions through each layer. I appreciate that it is really complicated to keep track of all the dimensions, even for me as the one who designed it, so I'm impressed you followed everything through as you did.

Emma :)

weicheng113 commented 2 years ago

Thanks a lot for the detailed explanation, Emma. I will need some time to digest your comments and re-read the relevant part of the paper. Sometimes the code seems more concrete to me than the paper, but the paper gives the high-level understanding. Thanks.

weicheng113 commented 2 years ago

Hi Emma,

With your comments, Figure 3, a re-reading of the paper, and checking against the code, I feel I now have a good understanding of temp_pointwise. I summarize my understanding below. The description of Figure 3 is a bit dry (I wrote it against your Figure 3). If you have time, please help check my understanding; any suggestions are highly appreciated. I wrote it with LyX, which makes formulas easier.

[Image: written summary of temp_pointwise, annotated against Figure 3 of the paper]

Thanks, Cheng

EmmaRocheteau commented 2 years ago

That’s great! I’ve gone through it carefully and it is all correct; you have really understood it well. As a quick test, can you tell me what happens to the mask features (the second channel in the original data, which indicates how recently the measurement was taken)?

weicheng113 commented 2 years ago

Thanks a lot for your time, Emma. I am not sure I understand your test.

a. The mask is paired with its corresponding feature. A grouped kernel of size 2 * kernel_size takes a dot product with each feature-mask pair across a time range of kernel_size (from the past up to the current time point t).
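Point (a) can be illustrated with NumPy. The interleaved layout (each value channel followed by its mask channel) and all sizes are assumptions for illustration; the point is that the grouped kernel, of size 2 x kernel_size, only ever sees one feature-mask pair over the window ending at t.

```python
import numpy as np

B, F, T, K = 1, 3, 6, 2          # hypothetical sizes; K plays the role of kernel_size
rng = np.random.default_rng(5)
values = rng.normal(size=(B, F, T))
masks = rng.uniform(size=(B, F, T))   # hypothetical mask values in [0, 1)

# Interleave so each feature's value channel is followed by its mask channel.
paired = np.stack((values, masks), axis=2).reshape(B, 2 * F, T)   # v0, m0, v1, m1, ...

# One 2 x K kernel per feature (a single output channel, for brevity).
w = rng.normal(size=(F, 2, K))

def first_layer_output(f, t):
    """Dot product of feature f's 2 x K kernel with its (value, mask) window ending at t."""
    pair = paired[0, 2 * f:2 * f + 2, :]                        # 2 * T
    padded = np.concatenate([np.zeros((2, K - 1)), pair], axis=1)  # causal left-padding
    return float(np.sum(w[f] * padded[:, t:t + K]))
```

For example, `first_layer_output(1, 3)` combines feature 1's values and masks at time steps 2 and 3 only.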

b. I have not looked in detail at how the mask decay fields are calculated in the preprocessing part (my classmate is working on that). I originally thought the mask was an indicator of whether a field has a measured value (i.e. presence), but I can see there are more details about it in the paper and the corresponding preprocessing code.

If I did not understand your question correctly, please let me know. Thanks.

EmmaRocheteau commented 2 years ago

a. Yes, you’re right. Sorry I wasn’t clear! The test was to see whether you’d understood that the mask features only appear as a second channel in layer one, and disappear after that (i.e. unlike the feature values themselves, they are not propagated forwards as skip connections). This is mostly because the added pointwise features don’t have mask variables.
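The channel bookkeeping this implies can be written down as a small function. It is a hypothetical sketch assuming a constant temp_kernels and point_size in every layer, F original features, and point_skip starting from the F values alone (no masks).

```python
def layer_input_layout(n_layers, F, temp_kernels, point_size):
    """Return (n_groups, channels_per_group) for the input of each layer (hypothetical).

    Assumes: layer 1 sees a (value, mask) pair per feature; later layers see
    1 skip + temp_kernels conv channels per feature; masks are never carried
    forward; each layer appends point_size pointwise features.
    """
    layouts = [(F, 2)]
    for n in range(1, n_layers):
        layouts.append((F + n * point_size, 1 + temp_kernels))
    return layouts

print(layer_input_layout(3, F=10, temp_kernels=12, point_size=13))
# [(10, 2), (23, 13), (36, 13)]
```

With these made-up sizes, features have two channels only in layer 1 and 13 channels (12 conv + 1 skip) per feature afterwards.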

weicheng113 commented 2 years ago

Thanks Emma. My understanding is that in the first layer, the grouping effect (C^n) acts on each feature-mask pair. From the second layer onwards, the grouping acts on the convolution channels of each normal feature (13 = 12 conv + 1 skip connection in the 2nd layer) and on the repeated values of each pointwise feature. The feature-mask pair in the first layer can be regarded as two special channels for a feature.

OK, so perhaps the mask could be carried into the second layer onwards (for example, 13 becomes 14 = 12 conv + 1 skip connection + 1 mask field in the 2nd layer), with the pointwise outputs repeated 14 times in the 2nd layer (but, as you point out, there is no corresponding mask field for them to match).
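This variant could be sketched as follows. It is hypothetical and not what the repository does: for the second layer, each original feature gets [skip value, temp_kernels conv outputs, mask], i.e. 14 channels, while the pointwise features, having no mask, are simply repeated 14 times. Sizes are made up.

```python
import numpy as np

B, F, T, temp_kernels, point_size = 1, 3, 5, 12, 2   # hypothetical sizes
rng = np.random.default_rng(6)
values = rng.normal(size=(B, F, T))
masks = rng.uniform(size=(B, F, T))
X_temp = rng.normal(size=(B, F * temp_kernels, T))
point_output = rng.normal(size=(B * T, point_size))

conv_dim = 1 + temp_kernels + 1   # skip + conv channels + mask = 14

# Original features: [skip value, 12 conv outputs, mask] -> 14 channels each.
temp_skip_mask = np.concatenate(
    (values[:, :, None, :], X_temp.reshape(B, F, temp_kernels, T), masks[:, :, None, :]),
    axis=2)
# Pointwise features have no mask, so all 14 slots hold copies of the same value.
X_point_rep = np.tile(point_output.reshape(B, T, point_size, 1).transpose(0, 2, 3, 1),
                      (1, 1, conv_dim, 1))
X_combined = np.concatenate((temp_skip_mask, X_point_rep), axis=1)
print(X_combined.shape)   # (1, 5, 14, 5)
```

In later layers, the pointwise features carried along in point_skip would also need something in the mask slot, which is the mismatch noted above.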

EmmaRocheteau commented 2 years ago

Yes exactly. You’ve understood it very thoroughly!

weicheng113 commented 2 years ago

Thank you very much for your time and guidance, Emma. I am very lucky to have confirmation and feedback from the paper author.