otiliastr / coarse-to-fine-curriculum

Coarse-to-Fine Curriculum Learning
Apache License 2.0

Is there code for calculating the distance matrix from the columns of the predictor (as in the paper)? #2

Closed nourheshamshaheen closed 8 months ago

nourheshamshaheen commented 8 months ago

Hello!

In the file hierarchy_computation.py, I noticed that the code calculates the distance matrix (for affinity clustering) from the confusion matrix of a trained model.

However, in the paper, you state that a better way to calculate this distance is to compute the cosine similarities between the columns of the predictor of a trained model.

Can you release the function that does this?

otiliastr commented 8 months ago

Hi,

Thanks for the question. The idea is to use the weights of the last layer of the model. The last layer projects from the previous layer to each class, so its weight matrix has shape #units_previous_layer x #num_classes. You can therefore use each column of this matrix as the representation of the corresponding class, and then compute the cosine distance between each pair of class representations.

Here is how I computed that:

import sklearn.metrics.pairwise


def compute_class_weight_distance(trainer):
    """Returns the pairwise cosine distances between class representations.

    The representation of each class is the column of the final (prediction)
    layer's weight matrix that corresponds to that class.
    """
    # The last-layer weights are reliable class embeddings only if the bias term is
    # handled correctly. The easiest way is to not include a bias term in the
    # prediction layer. Here we assert that the bias is not included and that there
    # is only one trainable variable.
    assert len(trainer.predictor.trainable_variables) == 1
    # The weight matrix has shape (units_previous_layer, num_classes); transpose it
    # so that each row is one class embedding.
    class_embeddings = trainer.predictor.trainable_variables[0].numpy().transpose()
    dist_matrix = sklearn.metrics.pairwise.cosine_distances(class_embeddings)
    return dist_matrix
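
For anyone who wants to check the computation without a trained model, here is a minimal, dependency-free sketch of the same idea on a toy weight matrix. The function name `cosine_distance_matrix` and the toy values are assumptions made for illustration and are not part of the repository. Raising an error on an all-zero column is also a choice made here, not something taken from the code above.

```python
import math


def cosine_distance_matrix(weights):
    """Pairwise cosine distances between the columns of `weights`.

    `weights` is a list of rows with shape (units_previous_layer, num_classes),
    so column j is treated as the embedding of class j.
    """
    # Transpose so that each entry is one class embedding (one column of the layer).
    class_embeddings = [list(column) for column in zip(*weights)]
    norms = [math.sqrt(sum(v * v for v in emb)) for emb in class_embeddings]
    # Assumption for this sketch: reject zero vectors, since their cosine is undefined.
    if any(n == 0.0 for n in norms):
        raise ValueError("cosine distance is undefined for an all-zero column")

    num_classes = len(class_embeddings)
    dist = [[0.0] * num_classes for _ in range(num_classes)]
    for i in range(num_classes):
        for j in range(i + 1, num_classes):
            dot = sum(a * b for a, b in zip(class_embeddings[i], class_embeddings[j]))
            # Cosine distance is one minus cosine similarity.
            d = 1.0 - dot / (norms[i] * norms[j])
            dist[i][j] = dist[j][i] = d
    return dist


# Hypothetical toy layer: 2 units in the previous layer, 3 classes.
toy_weights = [
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
]
toy_dist = cosine_distance_matrix(toy_weights)
```

In this toy example, classes 0 and 1 have orthogonal columns and so end up at distance 1, while class 2 lies between them and is closer to both. A matrix like `toy_dist` is what would then be passed to the clustering step in place of the confusion-matrix-based distances.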

nourheshamshaheen commented 8 months ago

Thank you for answering!