Closed tsly123 closed 2 years ago
I have figured it out. The error comes from the sizes of support_size and query_size. In the original code, these two parameters are specified per class; in my case, I want the sizes to be flexible.
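For concreteness, here is a minimal sketch of the per-class convention (the sizes below are illustrative only, not taken from the repo):

import torch

nclasses = 5
support_size = 1  # per class: 1 support sample for each of the 5 classes
query_size = 3    # per class: 3 query samples for each of the 5 classes

# Under the per-class layout, an episode holds support_size * nclasses
# support rows and query_size * nclasses query rows.
support_labels = torch.arange(nclasses).repeat(support_size)
print(support_labels)                                   # tensor([0, 1, 2, 3, 4])
print(support_size * nclasses, query_size * nclasses)   # 5 15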
I made some modifications to the code for later reference:
import torch

# Note: this function is meant to live inside
# embedding_propagation/embedding_propagation.py, so that
# embedding_propagation() and label_propagation() are in scope.
def test():
    support_size_total = 5   # total number of support samples (not per class)
    nclasses = 5
    query_size_total = 8     # total number of query samples (not per class)
    support_size = support_size_total // nclasses  # per-class support size
    query_size = query_size_total
    s_feat = torch.rand(support_size_total, 128)
    q_feat = torch.rand(query_size, 128)
    features = torch.cat([s_feat, q_feat], 0)
    feat = embedding_propagation(features, alpha=0.5, rbf_scale=1, norm_prop=False)
    # support labels: [0, ..., nclasses - 1] repeated row-wise support_size times
    support_labels = torch.arange(nclasses, device=feat.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses)
    # query samples are marked with the "unlabeled" id nclasses
    unlabeled_labels = nclasses * torch.ones(query_size, dtype=support_labels.dtype, device=support_labels.device).view(-1, query_size)
    labels = torch.cat([support_labels, unlabeled_labels], 1).view(-1)
    logits = label_propagation(feat, labels, nclasses, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True)
    # keep only the query rows
    logits = logits[support_size * nclasses:, ...].view(-1, nclasses)
    return logits
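As a quick sanity check of the sketch above (the expected shape is inferred from the sizes used, not output taken from the repo):

logits = test()
print(logits.shape)  # torch.Size([8, 5]): one row of class logits per query sample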
Thanks.
Thanks for the investigation!
Thank you for sharing your work.
I'm using embedding propagation in my work. To reproduce the error, I added a call to test() inside https://github.com/ServiceNow/embedding-propagation/blob/master/embedding_propagation/embedding_propagation.py and ran it.
The error is:
I traced it down, and the error occurs in this line: https://github.com/ServiceNow/embedding-propagation/blob/master/embedding_propagation/embedding_propagation.py#L74
The feature shapes are similar to those in get_logits() in https://github.com/ServiceNow/embedding-propagation/blob/master/src/models/finetuning.py#L95. I don't know what's wrong here. Could you help me? Thank you.