orobix / Prototypical-Networks-for-Few-shot-Learning-PyTorch

Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in PyTorch
MIT License

How do I make a prediction? #8

Closed. ale316 opened this issue 6 years ago.

ale316 commented 6 years ago

Thank you for your work!

I've trained a model for a few epochs, and now I'd like to make predictions with it. I load it:

import torch
from protonet import ProtoNet  # src/protonet.py in this repo
model = ProtoNet().cuda()
model.load_state_dict(torch.load('./output/best_model.pth'))

I load 15 labeled data points belonging to 3 classes:

# load_data() is my own loader: x are the images, y their labels
# x.size() -> (15, 64, 64)
# y.size() -> (15,)
x, y = load_data()

I load a single data point I want to predict:

to_predict = torch.Tensor(1, 64, 64)  # placeholder: uninitialized tensor of shape (1, 64, 64)

Now I would like to few-shot train on 5 examples per class and then predict a class for to_predict. How do I go about that?

ale316 commented 6 years ago

Alright, I've been digging a little deeper into the code and the paper.

To get predictions, I need to pass the prototypical_loss function the output of model(x), the correct labels target_y, and the number of "shots" as n_support.
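For reference, here is a stdlib-only sketch of how I understand the loss function to use n_support. For each class, visited in sorted order as `torch.unique` returns them by default, the first n_support samples become the support set and the remaining samples become queries. The name `split_support_query` is hypothetical. On tensors, the repo does the same thing with `target.eq(c).nonzero()[:n_support]`.

```python
def split_support_query(labels, n_support):
    """Split sample indices into per-class support and query indices.

    Hypothetical helper mirroring the idea in prototypical_loss: for each
    class (in sorted order), the first n_support occurrences are support
    samples and the remaining ones are query samples.
    """
    support, query = {}, {}
    for c in sorted(set(labels)):
        idxs = [i for i, y in enumerate(labels) if y == c]
        support[c] = idxs[:n_support]
        query[c] = idxs[n_support:]
    return support, query


# Example: 3 classes, 3 samples each, 2 "shots" per class
labels = [0, 1, 2, 0, 1, 2, 0, 1, 2]
support, query = split_support_query(labels, n_support=2)
print(support)  # {0: [0, 3], 1: [1, 4], 2: [2, 5]}
print(query)    # {0: [6], 1: [7], 2: [8]}
```

With 5 examples per class and n_support=5, every sample lands in the support set and no queries are left. At prediction time the support set is therefore the labeled data and the queries are the new inputs.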

The last couple of lines of the loss function are what give us the predictions. We have:

_, y_hat = log_p_y.max(2)
acc_val = y_hat.eq(target_inds.squeeze()).float().mean()

Correct me if I'm wrong: log_p_y is a tensor of distances from the centroids, and y_hat holds the index of the centroid closest to each query. But then what does y_hat.eq(target_inds.squeeze()).float() represent?

If the predicted labels were 0, 0, 2, it would look like:

tensor([[ 1.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 0.,  0.,  1.]])

What does that tensor represent?
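As far as I can tell from the repo's `prototypical_loss`, `log_p_y` holds log-probabilities rather than distances. It is the log-softmax of the negative distances, reshaped with `.view(n_classes, n_query, -1)` so that the first axis is the true class of each query. That makes `y_hat` a tensor of shape `(n_classes, n_query)`. `target_inds.squeeze()` has the same shape, and its row `i` is filled with `i`. Comparing the two gives a 0/1 mask that is 1 wherever a query was assigned to its true class. Its mean is the episode accuracy `acc_val`, so the tensor is not a one-hot encoding of the predictions. The class indices here are positions in `torch.unique(target)`, not the raw label values. Below is a stdlib-only sketch of that computation with made-up predictions; `accuracy_mask` is a hypothetical name.

```python
def accuracy_mask(y_hat):
    """0/1 mask of correct predictions and its mean (the episode accuracy).

    y_hat[i][j] is the predicted class index of the j-th query whose true
    class index is i, mirroring the (n_classes, n_query) layout assumed above.
    """
    mask = [[1.0 if pred == true_class else 0.0 for pred in row]
            for true_class, row in enumerate(y_hat)]
    n = sum(len(row) for row in mask)
    acc = sum(sum(row) for row in mask) / n
    return mask, acc


# 3 classes, 2 queries per class; one query of class 1 is mistaken for class 2
y_hat = [[0, 0],
         [1, 2],
         [2, 2]]
mask, acc = accuracy_mask(y_hat)
print(mask)  # [[1.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
print(acc)   # 0.8333333333333334
```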

ale316 commented 6 years ago

Sorry for the spam here, but for posterity:

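I wrote a function that should return predictions given:

  • a tensor support_x of size (n_support, 1024)
  • a tensor support_y of size (n_support,)
  • a tensor query_x of size (n_query, 1024)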

import torch
import torch.nn.functional as F
from prototypical_loss import euclidean_dist  # defined in this repo's src/prototypical_loss.py

def predict(support_x, support_y, query_x, query_y=None):
    # support_x and query_x are embeddings (model outputs), not raw images;
    # query_y is unused
    support_x = support_x.to('cpu')
    support_y = support_y.to('cpu')
    query_x = query_x.to('cpu')

    classes = torch.unique(support_y)

    # for each class, the indices of its support samples
    support_idxs = [support_y.eq(c).nonzero().squeeze(1) for c in classes]

    # the prototype (centroid) of each class is the mean of its support embeddings
    prototypes = torch.stack([support_x[idx_list].mean(0) for idx_list in support_idxs])

    # Euclidean distances between each query and each prototype: (n_query, n_classes)
    dists = euclidean_dist(query_x, prototypes)

    # log-probabilities over classes: log-softmax of the negative distances
    log_p_y = F.log_softmax(-dists, dim=1)

    # index of the most probable (= closest) prototype for each query: (n_query,)
    _, y_hat = log_p_y.max(1)

    # map prototype indices back to the original labels; y_hat is already 1-D,
    # so it can index classes directly, even when n_query == 1
    labels = classes[y_hat]

    return labels
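`euclidean_dist` is not defined in the snippet; it comes from the repo's `prototypical_loss.py`. As far as I can tell, it returns squared Euclidean distances, computed as `torch.pow(x - y, 2).sum(2)` over broadcast copies of the inputs. That matches the squared Euclidean distance used in the paper. Below is a stdlib-only sketch of the same pairwise computation, useful for checking shapes and values by hand. The name `squared_euclidean_dist` is hypothetical.

```python
def squared_euclidean_dist(xs, ys):
    """Pairwise squared Euclidean distances.

    xs: n vectors (e.g. query embeddings), ys: m vectors (e.g. prototypes).
    Returns an n x m nested list where out[i][j] = sum_k (xs[i][k] - ys[j][k]) ** 2.
    """
    dims = {len(v) for v in xs} | {len(v) for v in ys}
    if len(dims) > 1:
        raise ValueError("all vectors must have the same dimension")
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in ys] for x in xs]


queries = [[0.0, 0.0], [3.0, 4.0]]
prototypes = [[0.0, 0.0], [1.0, 1.0]]
print(squared_euclidean_dist(queries, prototypes))
# [[0.0, 2.0], [25.0, 13.0]]
```

The result has one row per query and one column per prototype. This is why `log_softmax` and `max` in `predict` run over `dim=1`.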

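Now, I still have some doubts:

  • Is this correct? Edit: yes it is
  • Why do we even softmax the distances, instead of just taking the min?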

dnlcrl commented 6 years ago

Sorry for the late reply. If you uncomment these lines in train.py (https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/8ed0c9a2de5a5281313927bd9db6ea592a23ec41/src/train.py#L247-L250), the code will automatically run the evaluation on the test dataset.

As for your question:

Why do we even softmax the distances, instead of just taking the min?

The answer is in the linked paper (https://arxiv.org/pdf/1703.05175.pdf):

our method is most similar to the non-linear extension of NCA [27] because we use a neural network to perform the embedding and we optimize a softmax based on Euclidean distances in the transformed space, as opposed to a margin loss. A key distinction between our approach and non-linear NCA is that we form a softmax directly over classes, rather than individual points, computed from distances to each class’s prototype representation

In either case, embedded query points are classified via a softmax over distances to class prototypes

And finally:

Given a distance function $d : \mathbb{R}^M \times \mathbb{R}^M \rightarrow [0, +\infty)$, prototypical networks produce a distribution over classes for a query point x based on a softmax over distances to the prototypes in the embedding space
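To connect this to the code, the paper defines $p_\phi(y = k \mid \mathbf{x}) = \frac{\exp(-d(f_\phi(\mathbf{x}), \mathbf{c}_k))}{\sum_{k'} \exp(-d(f_\phi(\mathbf{x}), \mathbf{c}_{k'}))}$. That is what `F.log_softmax(-dists, dim=1)` computes, in log space.

Softmax is monotonic, so picking the most probable class gives the same answer as picking the smallest distance. For prediction alone, the softmax therefore doesn't change the result. It matters during training: the loss is the negative log-probability of the true class, which gives a smooth signal pulling a query toward its own prototype and away from the others. It also gives you a probability per class if you need one.

Below is a stdlib-only check on a single query with made-up distances.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


# squared distances from one query to 3 prototypes (made-up numbers)
dists = [4.0, 1.0, 9.0]
# what F.log_softmax(-dists, dim=1) computes for this one row
log_p_y = log_softmax([-d for d in dists])

pred_softmax = max(range(len(log_p_y)), key=log_p_y.__getitem__)  # most probable class
pred_min = min(range(len(dists)), key=dists.__getitem__)          # closest prototype
print(pred_softmax, pred_min)  # 1 1

# training signal: if the true class of this query were 0, its loss would be
loss = -log_p_y[0]
print(loss)
print([round(math.exp(v), 3) for v in log_p_y])  # class probabilities
```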

dnlcrl commented 6 years ago

I closed the issue, but feel free to reopen it or add comments if you want to. Thank you.

MarioProjects commented 5 years ago

Hi guys! I'm stuck at this point and I just don't understand it. I already have the trained model. For testing, my understanding is that I should compute the embedding of my sample and classify it by whichever centroid is closest. But if I trained with 5000 classes, I don't understand at all why the test phase needs both a support set and a query set.

Taking @ale316's implementation, predict(support_x, support_y, query_x, query_y=None), we would pass a support set taken from the training set, plus queries whose y we don't know (hence query_y=None). But how does it take all 5000 classes into account if the support set only includes, say, 30 of them?

My idea is to compute the embeddings of all the training samples and build the centroids from them. Then I would compute the embeddings of the test samples and assign each one the class of the centroid at the smallest Euclidean distance. Am I correct? That's not what @ale316's function does, though. Quoting @ale316:

Sorry for the spam here, but for posterity:

I wrote a function that should return predictions given:

  • a tensor support_x of size (n_support, 1024)
  • a tensor support_y of size (n_support,)
  • a tensor query_x of size (n_query, 1024)
[same predict() code as in @ale316's comment above]

Now, I still have some doubts:

  • Is this correct? Edit: yes it is
  • Why do we even softmax the distances, instead of just taking the min?
dnlcrl commented 5 years ago

@MarioProjects I think that in the validation phase you shouldn't use the centroids you got from the training phase. The whole purpose of few-shot learning is to correctly classify even classes never seen before, and if you only save the training centroids, you will never have centroids for those unseen classes. So IMO, in the validation phase you should "forget" the centroids, compute new ones, and find the nearest centroid for each sample.

In test/production, where you usually have classes never seen before, you have two options. You could use the centroids for all the classes (train + valid) and classify the samples based on those. Or you could monitor the output and add new classes to the dataset when you see new clusters forming.
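Here is a minimal stdlib-only sketch of the production option described above. It assumes you already have embeddings from the trained model. It keeps one running-mean prototype per known class (for example, train + valid) and classifies by nearest prototype. A new class can be registered as soon as you have a few labeled embeddings for it. `PrototypeBank` and its methods are hypothetical names, not part of the repo.

```python
class PrototypeBank:
    """Hypothetical store of class prototypes (mean embeddings)."""

    def __init__(self):
        self._sums = {}    # label -> running sum of embeddings
        self._counts = {}  # label -> number of embeddings seen

    def add(self, label, embeddings):
        """Add labeled embeddings; creates the class if it is new."""
        for e in embeddings:
            if label not in self._sums:
                self._sums[label] = [0.0] * len(e)
                self._counts[label] = 0
            self._sums[label] = [s + v for s, v in zip(self._sums[label], e)]
            self._counts[label] += 1

    def prototype(self, label):
        """Mean of all embeddings added for this label."""
        return [s / self._counts[label] for s in self._sums[label]]

    def classify(self, embedding):
        """Label whose prototype is closest (squared Euclidean); needs >= 1 class."""
        def sq_dist(label):
            return sum((a - b) ** 2 for a, b in zip(embedding, self.prototype(label)))
        return min(self._sums, key=sq_dist)


bank = PrototypeBank()
bank.add("cat", [[0.0, 0.0], [0.0, 2.0]])   # prototype (0, 1)
bank.add("dog", [[4.0, 4.0]])               # prototype (4, 4)
print(bank.classify([1.0, 1.0]))            # cat
bank.add("fox", [[1.0, 1.5]])               # a class never seen in training
print(bank.classify([1.0, 1.0]))            # fox
```

Storing sums and counts instead of the prototypes themselves lets you add more labeled examples to a class later without keeping every embedding around.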

stellaywu commented 4 years ago

Thanks for the library, and thanks @ale316 for the implementation! At inference time there are no labels, so what would support_y be?

Ashigarg123 commented 4 months ago

Any update on how to do inference to predict the class?
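On the support_y question: at inference time only the queries are unlabeled. Few-shot classification still needs a handful of labeled examples (the "shots") for every class you want to recognize. Those examples and their labels are support_x and support_y, and a prediction can only ever be one of the labels present in support_y.

Below is a stdlib-only sketch of the whole inference flow under that assumption. `few_shot_predict` and `toy_embed` are hypothetical names. `toy_embed` stands in for running the trained ProtoNet on one input and flattening the output.

```python
def few_shot_predict(embed, support, queries):
    """Classify unlabeled queries from a few labeled support examples.

    embed:   callable mapping one raw input to an embedding (list of floats);
             a stand-in for the trained network's forward pass.
    support: list of (raw_input, label) pairs; the labels play the role of support_y.
    queries: list of raw inputs with no labels.
    Returns one predicted label per query, always one of the support labels.
    """
    sums, counts = {}, {}
    for x, label in support:
        e = embed(x)
        acc = sums.setdefault(label, [0.0] * len(e))
        for k, v in enumerate(e):
            acc[k] += v
        counts[label] = counts.get(label, 0) + 1
    # prototype = mean embedding of each class's support examples
    prototypes = {c: [s / counts[c] for s in sums[c]] for c in sums}

    preds = []
    for q in queries:
        e = embed(q)
        # nearest prototype by squared Euclidean distance
        preds.append(min(prototypes,
                         key=lambda c: sum((a - b) ** 2 for a, b in zip(e, prototypes[c]))))
    return preds


def toy_embed(x):
    # hypothetical stand-in for model(x): mean and max of the input
    return [sum(x) / len(x), max(x)]


support = [([0.0, 0.1], "a"), ([0.1, 0.0], "a"), ([1.0, 0.9], "b"), ([0.9, 1.0], "b")]
print(few_shot_predict(toy_embed, support, [[0.05, 0.0], [1.0, 1.0]]))  # ['a', 'b']
```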