Trustworthy-ML-Lab / Label-free-CBM

A new framework to transform any neural network into an interpretable concept bottleneck model (CBM) without needing labeled concept data

Some questions about training #8

Closed: Mr-Monday closed this issue 5 months ago

Mr-Monday commented 6 months ago

Hello!

In train_cbm.py, from L157 to L173, why are the validation target features used to delete concepts that are not interpretable? After doing this, in 'W_c = proj_layer.weight[interpretable]; proj_layer = torch.nn.Linear(in_features=target_features.shape[1], out_features=len(concepts), bias=False); proj_layer.load_state_dict({"weight":W_c})', the parameter W_c is selected based on the validation set and then used for training via train_c = proj_layer(target_features.detach()) (i.e. L179). Why are parameters (W_c) obtained with the help of the validation set used for model training (from L171 to L203)? Will this cause leakage of validation data?

Thank you in advance for your support and guidance!

tuomaso commented 5 months ago

W_c is not obtained with validation data; the parameters are trained only on the training data. The only loss that is backpropagated is the loss on training data (see L128-L131). The validation data is only used for early stopping (L133-143) and to filter out concepts that were not learned properly (L157-161). We used validation data instead of training data for concept filtering to measure whether the concept representations (W_c) we learn generalize to new data, and to delete concepts that overfit to the training data. Ideally this filtering should use a validation set separate from the one we evaluate on, which we did not do, but I don't think using validation data this way has a significant effect (it is not used for learning any parameters).
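For reference, here is a minimal sketch of what this validation-based filtering step looks like conceptually. The names (val_target_features, val_concept_targets, interpretability_cutoff) and the plain cosine-similarity criterion are illustrative assumptions, not necessarily the exact code in train_cbm.py; the key point is that validation data only selects which already-trained rows of W_c to keep.

```python
import torch

def filter_concepts(proj_layer, val_target_features, val_concept_targets,
                    concepts, interpretability_cutoff=0.45):
    """Keep only concepts whose learned projection generalizes to validation data.

    proj_layer maps backbone features to one activation per concept; a concept is
    kept when its projected validation activations are similar enough (cosine
    similarity above the cutoff) to that concept's target activations.
    """
    with torch.no_grad():
        val_proj = proj_layer(val_target_features)  # shape: (N_val, n_concepts)
        # Per-concept similarity between learned activations and target activations,
        # computed on validation data only (no gradients, no parameter updates).
        sims = torch.nn.functional.cosine_similarity(
            val_proj - val_proj.mean(dim=0),
            val_concept_targets - val_concept_targets.mean(dim=0),
            dim=0,
        )
    interpretable = sims > interpretability_cutoff
    kept_concepts = [c for c, keep in zip(concepts, interpretable) if keep]

    # Rebuild the projection layer with only the surviving concepts. The weights
    # W_c were learned on training data; validation only decides which rows to keep.
    W_c = proj_layer.weight[interpretable]
    new_proj = torch.nn.Linear(val_target_features.shape[1], len(kept_concepts), bias=False)
    new_proj.load_state_dict({"weight": W_c})
    return new_proj, kept_concepts
```

In this sketch no parameter is ever fit to validation data: the boolean mask is the only thing derived from it, which is why the subsequent training of the final layer on train_c = proj_layer(target_features.detach()) does not leak validation labels into the learned weights.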

Mr-Monday commented 5 months ago

Thank you for your reply.