GMFTBY / Copyisallyouneed

[ICLR 2023] Codebase for Copy-Generator model, including an implementation of kNN-LM
https://openreview.net/forum?id=CROlOA9Nd8C
MIT License

Question about the data processing in "encode_doc"? #13

Open FutureWithoutEnding opened 11 months ago

FutureWithoutEnding commented 11 months ago

The code in data/dpr_wikitext103_1024/encode_doc.py:

def inference(**args):
    data = DPRDataset(args['data_path'])
    sampler = torch.utils.data.distributed.DistributedSampler(data)
    data_iter = DataLoader(data, batch_size=args['batch_size'], collate_fn=data.collate, sampler=sampler)
    sampler.set_epoch(0)

    text_lists, embeddings, size, counter = [], [], 0, 0
    for documents, labels in tqdm(data_iter):
        embed = inference_one_batch(documents)
        text_lists.extend(labels)
        embeddings.append(embed)
        size += len(embed)
        if len(embeddings) > args['cut_size']:
            embed = torch.cat(embeddings)
            torch.save((text_lists, embed), f'dpr_chunk_{args["local_rank"]}_{counter}.pt')
            counter += 1
            embeddings = []  # embeddings is cleared here, but text_lists is not
    if len(embed) > 0:
        embed = torch.cat(embeddings)
        torch.save((text_lists, embed), f'dpr_chunk_{args["local_rank"]}_{counter}.pt')

Is this part of the code correct? I think text_lists should also be cleared when embeddings is reset to []. Otherwise every chunk after the first is saved with the texts of all previous chunks included, so the saved texts no longer line up with the saved embeddings.
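
For reference, this is roughly the change I have in mind, a minimal sketch that assumes the same surrounding definitions as encode_doc.py (DPRDataset, inference_one_batch, and the args dict). I also changed the final check from len(embed) to len(embeddings), since that is the buffer actually being concatenated, but please correct me if I am misreading the intent:

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def inference(**args):
    data = DPRDataset(args['data_path'])
    sampler = torch.utils.data.distributed.DistributedSampler(data)
    data_iter = DataLoader(data, batch_size=args['batch_size'], collate_fn=data.collate, sampler=sampler)
    sampler.set_epoch(0)

    text_lists, embeddings, size, counter = [], [], 0, 0
    for documents, labels in tqdm(data_iter):
        embed = inference_one_batch(documents)
        text_lists.extend(labels)
        embeddings.append(embed)
        size += len(embed)
        if len(embeddings) > args['cut_size']:
            # dump the accumulated chunk to disk
            embed = torch.cat(embeddings)
            torch.save((text_lists, embed), f'dpr_chunk_{args["local_rank"]}_{counter}.pt')
            counter += 1
            embeddings = []
            text_lists = []  # reset the texts together with the embeddings so each chunk stays aligned
    if len(embeddings) > 0:  # flush whatever is left in the buffer
        embed = torch.cat(embeddings)
        torch.save((text_lists, embed), f'dpr_chunk_{args["local_rank"]}_{counter}.pt')

With both changes, each saved dpr_chunk_*_*.pt should contain exactly the texts that belong to its embedding matrix.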