Closed YuSawan closed 1 year ago
This is because the number of "Soundtrack" mentions in the training set is only 50, so we cannot sample a 100-shot support set for this type.
To sample a support set, you can refer to the following code:
class Sampler:
    """Iterator that yields a K-shot support set over the given classes on each step.

    Each element of `samples` is a dict carrying a 'class_count' mapping of
    class label -> number of mentions of that class in the sample.  A support
    set is a list of sample indices chosen so that every label in `classes`
    accumulates at least K mentions — unless the dataset simply does not
    contain K mentions of a label, in which case all available samples for it
    are taken and the set is short for that label (e.g. only 50 "Soundtrack"
    samples exist for a 100-shot request).
    """

    def __init__(self, N, K, samples, classes):
        self.K = K
        # N (N-way) is stored for interface compatibility; the sampling
        # itself is driven entirely by `classes`.
        self.N = N
        self.samples = samples
        self.classes = classes
        self.support_idx = []  # indices selected for the current support set

    def __get_candidates__(self, target_class):
        # Samples that mention `target_class` and are not already selected.
        return [idx for idx, sample in enumerate(self.samples)
                if target_class in sample['class_count'] and idx not in self.support_idx]

    def __next__(self):
        """Draw a fresh support set; returns (None, support_idx, None)."""
        support_class = {label: 0 for label in self.classes}  # mentions gathered so far
        self.support_idx = []
        for label in self.classes:
            while support_class[label] < self.K:
                candidates = self.__get_candidates__(label)
                if len(candidates) + support_class[label] <= self.K:
                    # Exactly enough — or too few — remain: take them all and
                    # move on.  (This is where an under-represented label ends
                    # up with fewer than K shots.)
                    for index in candidates:
                        # `candidates` already excludes selected indices, so
                        # each one can be appended unconditionally.
                        support_class[label] += 1
                        self.support_idx.append(index)
                    break
                # Plenty of candidates: pick one at random and re-evaluate.
                index = random.choice(candidates)
                support_class[label] += 1
                self.support_idx.append(index)
        # Triple mirrors a (query, support, extra) episode shape; only the
        # support indices are populated here.
        return None, self.support_idx, None

    def __iter__(self):
        return self
def getsupportset(N, K, dataset, labels, num_sets=10):
    """Draw `num_sets` independent K-shot support sets from `dataset`.

    N: number of classes (N-way), forwarded to the Sampler.
    K: shots per class.
    dataset: samples annotated with a 'class_count' dict.
    labels: class labels every support set must cover.
    num_sets: how many support sets to generate (default 10, matching the
        original hard-coded behavior).
    Returns a list of `num_sets` index lists into `dataset`.
    """
    sampler = Sampler(N, K, dataset, labels)
    print(N)
    print(K)
    print(labels)
    data = []
    for _ in tqdm.tqdm(range(num_sets)):
        # next() is the idiomatic way to advance an iterator
        # (rather than calling sampler.__next__() directly).
        _, support_idx, _ = next(sampler)
        data.append(support_idx)
    return data
def getdata(file, K=5):
    """Build K-shot support-set files from `<file>/train.json`.

    Reads one JSON object per line, annotates each sentence with a
    'class_count' dict (entity type -> mentions in that sentence), tallies
    global per-type counts, samples 10 support sets over all types (most
    frequent first), and writes them to `<file>/<K>shot.json` as one JSON
    object per line: {'support': [...], 'target_label': [...]}.
    """
    print(file)
    outfile = file
    dataset = []
    nums = {}  # global mention count per entity type, in first-seen order
    with open(file + '/train.json') as f:
        for line in tqdm.tqdm(f):
            line = json.loads(line)
            class_count = {}  # per-sentence mention count
            for entity in line['entity']:
                t = entity['type']
                nums[t] = nums.get(t, 0) + 1
                class_count[t] = class_count.get(t, 0) + 1
            line['class_count'] = class_count
            dataset.append(line)
    # Most frequent types first; `labels` follows that order.
    nums = sorted(nums.items(), key=lambda x: x[1], reverse=True)
    print(nums)
    labels = [i[0] for i in nums]
    supports = getsupportset(len(labels), K, dataset, labels)
    print('sentencenum:', len(supports[0]))
    target_label = labels
    with open(outfile + '/' + str(K) + 'shot.json', 'w') as f:
        for support in supports:
            s = [dataset[index] for index in support]
            f.write(json.dumps({'support': s, 'target_label': target_label}) + '\n')
I was able to generate few-shot samples from my dataset. Thank you so much!
I checked the data you provided in order to build few-shot data with other datasets, but the data size seems to differ slightly between support sets (also, there are only 50 examples of "Soundtrack" in mit-movie1/100-shot.json).
Could you share with me how to make few-shot data?