mahmoodlab / CLAM

Data-efficient and weakly supervised computational pathology on whole slide images - Nature Biomedical Engineering
http://clam.mahmoodlab.org
GNU General Public License v3.0

Using Feature Based training to detect lesion vs Normal on PANDA Dataset #16

Closed Ajaz-Ahmad closed 3 years ago

Ajaz-Ahmad commented 3 years ago

Hi Mahmood,

I really liked your work; the MIL approach for fitting a whole-slide image into memory was very motivating. I am working with the same technique and have one or two questions; any suggestions from your experience would help me a lot.

-> Objective: use slide-level metadata to train the classifier (MIL, weakly supervised learning).

-> Dataset: PANDA (https://www.kaggle.com/c/prostate-cancer-grade-assessment) has roughly 10k slides covering different Gleason grades. I treat the negative (normal) grade as class 0 and combine [4+4, 5+5] as the positive class 1.
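For reference, a minimal sketch of this label mapping, assuming the public PANDA train.csv with image_id and gleason_score columns (the column names and the "negative"/"0+0" encodings are assumptions based on the Kaggle dataset, not part of the original post):

```python
# Hypothetical sketch of the binary label mapping described above.
import pandas as pd

df = pd.read_csv("train.csv")  # assumed PANDA metadata file

# Negative class: normal tissue; positive class: Gleason 4+4 and 5+5 slides.
negatives = df[df["gleason_score"].isin(["negative", "0+0"])].copy()
positives = df[df["gleason_score"].isin(["4+4", "5+5"])].copy()
negatives["label"] = 0
positives["label"] = 1

subset = pd.concat([negatives, positives], ignore_index=True)
print(subset["label"].value_counts())
```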

-> Feature extraction: from the previous step we now have a binary problem. I used ResNet50 with the preprocess_input function available in Keras; features are taken from the conv4_block6_out layer, giving shape [?, 32, 32, 1024] for input tiles of shape (512, 512, 3). Tiles are extracted from the WSI, and I used the annotations provided with the data to filter out background, keeping only tiles with more than 70% tissue.
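A minimal, self-contained sketch of this extraction step using the standard tf.keras ResNet50 API (the random tile below is just a placeholder for a real tissue tile):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

# ImageNet-pretrained backbone, truncated at conv4_block6_out (stride 16, 1024 channels).
base = ResNet50(weights="imagenet", include_top=False, input_shape=(512, 512, 3))
extractor = tf.keras.Model(inputs=base.input,
                           outputs=base.get_layer("conv4_block6_out").output)

# Placeholder tile; in practice this would be a 512x512 tissue tile from the WSI.
tile = np.random.randint(0, 255, size=(1, 512, 512, 3)).astype("float32")
features = extractor(preprocess_input(tile))
print(features.shape)  # (1, 32, 32, 1024)
```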

-> Training network: I am using a simple attention network to classify the two classes from the features, but the network alternates between predicting all 0s in one epoch and all 1s in the next, with accuracies of 56% and 44% respectively, which matches the class distribution of the training data. The network is below:

```python
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Dropout, multiply
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

# input_dim, weight_decay and useGated are set elsewhere;
# Mil_Attention and Last_Sigmoid are custom MIL layers from my implementation.
data_input = Input(shape=input_dim, dtype='float32', name='input')
conv = Conv2D(filters=512, kernel_size=(1, 1), strides=(2, 2), padding='valid', activation='relu')(data_input)
x = Flatten()(conv)
fc2 = Dense(512, activation='relu', kernel_regularizer=l2(weight_decay), name='fc2')(x)
fc2 = Dropout(0.5)(fc2)
alpha = Mil_Attention(L_dim=128, output_dim=1, kernel_regularizer=l2(weight_decay), name='alpha', use_gated=useGated)(fc2)
x_mul = multiply([alpha, fc2])
out = Last_Sigmoid(output_dim=1, name='FC1_sigmoid')(x_mul)
model = Model(inputs=[data_input], outputs=[out])
```

Loss Metric:

```python
import tensorflow as tf
from tensorflow.keras import backend as K

class bag_loss(tf.keras.losses.Loss):
    def __init__(self, name='bag_loss'):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        # Average the per-tile predictions of the bag into a single bag-level score.
        y_pred = K.mean(y_pred, axis=0, keepdims=False)
        y_pred = tf.squeeze(y_pred, axis=0)
        loss = K.binary_crossentropy(y_true, y_pred)
        loss = tf.keras.backend.cast(loss, dtype=tf.float32)
        return loss
```
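For context, a minimal self-contained sketch of how a bag-level loss like this could be exercised, with a toy stand-in model in place of the attention network above and one slide (bag) per training step; the toy model, bag size and random features are placeholders for illustration only:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the attention model: flattens each tile's feature map and scores it.
toy_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(32, 32, 1024)),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
toy_model.compile(optimizer='adam', loss=bag_loss())

# One bag = all tile features of one slide; the slide label is replicated per tile
# so that x and y have matching batch sizes.
bag = np.random.rand(8, 32, 32, 1024).astype('float32')
bag_labels = np.ones((8, 1), dtype='float32')
print(toy_model.train_on_batch(bag, bag_labels))
```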

Can you give any suggestions on this?

fedshyvana commented 3 years ago

Hi Ajaz, sorry, I'm not at all familiar with Keras-based deep learning implementations, so I can't really help you there. However, if the network is predicting all of one class or all of the other, it is possible that something is wrong with your data sampling (e.g. your data loader samples all positive examples in consecutive batches/epochs followed by all negative examples). I would suggest carefully double-checking your entire code for potential bugs, looking at how the training/validation losses behave, and verifying that the dataset is properly sampled during training.

Max
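As an illustration of the sampling check suggested above, a minimal sketch (with hypothetical slide_ids and labels placeholders) of reshuffling the slide list every epoch so that positive and negative bags are interleaved rather than visited class by class:

```python
import random

slide_ids = [f"slide_{i}" for i in range(100)]              # hypothetical slide identifiers
labels = {sid: random.randint(0, 1) for sid in slide_ids}   # hypothetical slide-level labels

for epoch in range(10):
    order = slide_ids[:]
    random.shuffle(order)                                   # interleave classes every epoch
    for sid in order:
        bag_label = labels[sid]
        # load the tile features for `sid` and run one bag-level training step here
        ...
```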