Closed: Mr-Monday closed this issue 5 months ago.
W_c is not obtained with validation data; the parameters are trained only on the training data. The only loss that is backpropagated is the loss on the training data (see L128-L131). The validation data is used only for early stopping (L133-L143) and to filter out concepts that were not learned properly (L157-L161). We used validation data instead of training data for concept filtering to measure whether the learned concept representations (W_c) generalize to new data, and to delete concepts that overfit to the training data. Ideally this filtering should use a validation set separate from the one we evaluate on, which we did not do, but I don't think using the validation data this way has a significant effect (it is not used for learning any parameters).
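If it helps to make that data flow concrete, here is a minimal self-contained sketch in PyTorch. The synthetic data, the correlation score, and the 0.45 cutoff are my illustrative stand-ins, not the repo's exact code; the point is only that the validation pass runs under `no_grad`, so it selects rows of W_c but never updates them:

```python
import torch

torch.manual_seed(0)
n_concepts, d = 5, 16

# W_c: concept projection trained only on training data (training loop omitted)
proj_layer = torch.nn.Linear(d, n_concepts, bias=False)

# Synthetic stand-ins for backbone features / concept targets on the val set.
val_features = torch.randn(100, d)
val_targets = proj_layer(val_features).detach().clone()
val_targets[:, 3:] = torch.randn(100, 2)  # pretend two concepts fail to generalize

# Validation data is used only inside no_grad: it scores concepts but never
# produces gradients, so no parameter is fit to it.
with torch.no_grad():
    val_pred = proj_layer(val_features)
    vp = (val_pred - val_pred.mean(0)) / (val_pred.std(0) + 1e-8)
    vt = (val_targets - val_targets.mean(0)) / (val_targets.std(0) + 1e-8)
    sim = (vp * vt).mean(0)       # per-concept correlation on val data
    interpretable = sim > 0.45    # keep only concepts that generalize

# Rebuild the projection layer from the surviving rows of W_c; the kept
# weights are unchanged, poorly generalizing rows are simply dropped.
W_c = proj_layer.weight[interpretable]
proj_layer = torch.nn.Linear(d, int(interpretable.sum()), bias=False)
proj_layer.load_state_dict({"weight": W_c})
```

Deleting rows this way changes which concepts survive, but the surviving weights are exactly the ones learned on training data.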
Thank you for your reply.
Hello!
In train_cbm.py, from L157 to L173, why are the val target features used to delete concepts that are not interpretable? In this code:

```python
W_c = proj_layer.weight[interpretable]
proj_layer = torch.nn.Linear(in_features=target_features.shape[1], out_features=len(concepts), bias=False)
proj_layer.load_state_dict({"weight": W_c})
```

the parameter W_c is selected based on the val set and is then used for training via `train_c = proj_layer(target_features.detach())` (i.e. L179). So why are parameters (W_c) obtained from the val set used for model training (L171 to L203)? Won't this cause leakage of the val data?
Thank you in advance for your support and guidance!