ServiceNow / embedding-propagation

Codebase for Embedding Propagation: Smoother Manifold for Few-Shot Classification. This is a ServiceNow Research project that was started at Element AI.
Apache License 2.0

Problem/Question about Label propagation #28

Closed tsly123 closed 2 years ago

tsly123 commented 2 years ago

Thank you for sharing your work.

I'm using embedding and label propagation in my work. To reproduce the error, I added the following test() inside https://github.com/ServiceNow/embedding-propagation/blob/master/embedding_propagation/embedding_propagation.py

def test():
    support_size = 1    # or 5; both yield the same error
    nclasses = 5
    query_size = 10

    s_feat = torch.rand(support_size * nclasses, 128)   # support features
    q_feat = torch.rand(query_size, 128)                # query features
    features = torch.cat([s_feat, q_feat], 0)

    feat = embedding_propagation(features, alpha=0.5, rbf_scale=1, norm_prop=False)
    support_labels = torch.arange(nclasses, device=feat.device).view(1, nclasses).repeat(support_size, 1).view(
        support_size, nclasses)
    unlabeled_labels = nclasses * torch.ones(query_size * nclasses, dtype=support_labels.dtype,
                                             device=support_labels.device).view(query_size, nclasses)
    labels = torch.cat([support_labels, unlabeled_labels], 0).view(-1)
    logits = label_propagation(feat, labels, nclasses, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True)
    logits = logits.view(-1, nclasses, nclasses)[support_size:(support_size + query_size), ...].view(-1, nclasses)
    print(logits)
    return logits

and printed the shapes:

print('features', feat.shape)    # torch.Size([15, 128])
print('support_labels', support_labels.shape)  # torch.Size([1, 5])
print('unlabeled_labels', unlabeled_labels.shape)  # torch.Size([10, 5])
print('labels', labels.shape)  # torch.Size([55])

The error is:

File "/project/hnguyen2/stly/code/fewshot/fsnoise/builder/label_prop.py", line 85, in label_propagation
    y_pred = torch.mm(propagator, labels)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (15x15 and 55x5)

I traced it down to this line: https://github.com/ServiceNow/embedding-propagation/blob/master/embedding_propagation/embedding_propagation.py#L74

The feature shapes match those passed to get_logits() in https://github.com/ServiceNow/embedding-propagation/blob/master/src/models/finetuning.py#L95

I don't know what's wrong here. Could you help me? Thank you.
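For anyone hitting the same traceback, the mismatch can be reproduced without the repo at all. This is a standalone sketch of my understanding (an assumption based on the error message, not the library's actual internals): label propagation multiplies an (N, N) propagator, where N is the total number of feature rows, by a one-hot label matrix, so that matrix must also have N rows.

```python
import torch

N = 15  # 5 support + 10 query features, as in the test() above
C = 5   # nclasses

propagator = torch.rand(N, N)       # (N, N) graph propagator
onehot = torch.rand(N, C)           # correct: one label row per feature row
y_pred = torch.mm(propagator, onehot)   # (15, 15) @ (15, 5) -> (15, 5)

bad_onehot = torch.rand(55, C)      # 55 label rows but only 15 features
# torch.mm(propagator, bad_onehot) raises:
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (15x15 and 55x5)
```

So the 55-element labels tensor built in test() is the problem: it must have exactly as many entries as there are feature rows.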

tsly123 commented 2 years ago

I have figured it out. The problem was my interpretation of support_size and query_size: in the original code, these two parameters are counts per class, while I wanted the total sizes to be flexible. Here is my modified version, for later reference.

def test():
    support_size_total = 5   # total number of support examples
    nclasses = 5
    query_size_total = 8     # total number of query examples

    support_size = support_size_total // nclasses   # supports per class
    query_size = query_size_total

    s_feat = torch.rand(support_size_total, 128)   # support features
    q_feat = torch.rand(query_size, 128)           # query features
    features = torch.cat([s_feat, q_feat], 0)

    feat = embedding_propagation(features, alpha=0.5, rbf_scale=1, norm_prop=False)
    support_labels = torch.arange(nclasses, device=feat.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses)
    # queries get the "unlabeled" class index nclasses
    unlabeled_labels = nclasses * torch.ones(query_size, dtype=support_labels.dtype, device=support_labels.device).view(-1, query_size)
    labels = torch.cat([support_labels, unlabeled_labels], 1).view(-1)   # one label per feature row
    logits = label_propagation(feat, labels, nclasses, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True)
    logits = logits[support_size * nclasses:, ...].view(-1, nclasses)    # keep only the query logits

    return logits
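The corrected label construction can be sanity-checked standalone, without calling embedding_propagation or label_propagation (this sketch only verifies the shape invariant, using the same sizes as above):

```python
import torch

support_size_total, nclasses, query_size_total = 5, 5, 8
support_size = support_size_total // nclasses          # supports per class

features = torch.rand(support_size_total + query_size_total, 128)

support_labels = torch.arange(nclasses).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses)
unlabeled_labels = nclasses * torch.ones(query_size_total, dtype=torch.long).view(-1, query_size_total)
labels = torch.cat([support_labels, unlabeled_labels], 1).view(-1)

# One label per feature row: the invariant the original per-class sizing
# silently enforced, and the one the buggy first test() violated.
assert labels.shape[0] == features.shape[0]            # 13 == 13
```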

Thanks.

prlz77 commented 2 years ago

Thanks for the investigation!