laiyongkui1997 / FewJoint

9 stars 2 forks source link

A few changes to consider #1

Open niloofar17 opened 3 years ago

niloofar17 commented 3 years ago

Hi,

I was trying to run few_joint_slu_1_bert.sh on the snips data and ran into some errors. Below is a short explanation of how I fixed them:

class RawDataLoaderBase:
    """Base class for loaders that read a raw data file.

    Subclasses override :meth:`load_data`; this base only provides a
    permissive constructor and a no-op ``load_data``.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments so subclasses can forward whatever
        # they receive without the base class rejecting it.
        pass

    # Alias kept so existing explicit ``init(...)`` calls (e.g. a subclass
    # doing ``super(...).init()``) keep working.
    init = __init__

    def load_data(self, path: str):
        """Load the data stored at ``path``.

        The base implementation does nothing and returns ``None``; concrete
        loaders override it.
        """
        pass

# Immutable record for one data item: the input sequence (``seq_in``), the
# output sequence (``seq_out``) and its ``label``.
# NOTE(review): for SLU these are presumably tokens / slot tags / intent — confirm.
DataItem = collections.namedtuple("DataItem", "seq_in seq_out label")

class FewShotExample:
    """A few-shot example: one query data item paired with its whole support set."""

    def __init__(
            self,
            gid: int,
            batch_id: int,
            test_id: int,
            support_data_items: List[DataItem],
            test_data_item: DataItem
    ):
        # Bookkeeping ids: example id, few-shot batch (episode) index, and the
        # query's relative index inside that episode.
        self.gid, self.batch_id, self.test_id = gid, batch_id, test_id

        # Payload: every support item of the episode plus the single query item.
        self.support_data_items = support_data_items
        self.test_data_item = test_data_item

    def __str__(self):
        return repr(self)

    def __repr__(self):
        # One "name:value" field per line; every field after the first is
        # indented with a tab.
        query = self.test_data_item
        lines = [
            'gid:{}'.format(self.gid),
            'test_data:{}'.format(query.seq_in),
            'test_label:{}'.format(query.seq_out),
            'support_data:{}'.format(self.support_data_items),
        ]
        return '\n\t'.join(lines)

class FewShotRawDataLoader(RawDataLoaderBase):
    """Load few-shot episodes from a JSON file into ``FewShotExample`` objects."""

    def __init__(self, opt):
        """
        input:
            opt: run options; ``opt.do_debug`` turns on debugging mode, in which
                ``load_data`` truncates what it returns.
        """
        super().__init__()
        self.opt = opt
        self.debugging = opt.do_debug

    # Alias kept so existing explicit ``init(opt)`` calls keep working.
    init = __init__

    def load_data(self, path: str) -> (List[FewShotExample], List[List[FewShotExample]], int):
        """
        load few shot data set
        input:
            path: path of the JSON data file
        output:
            examples: a list of all examples loaded from path
            few_shot_batches: a list of few-shot batches, each batch is a list of examples
            max_support_size: the largest support-set size over all batches
        """
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(path, 'r', encoding='utf-8') as reader:
            raw_data = json.load(reader)
        examples, few_shot_batches, max_support_size = self.raw_data2examples(raw_data)
        if self.debugging:
            # NOTE(review): the two slices are independent, so the kept examples
            # are not necessarily exactly those contained in the kept batches.
            examples, few_shot_batches = examples[:8], few_shot_batches[:2]
        return examples, few_shot_batches, max_support_size

    def raw_data2examples(self, raw_data: List[Dict]) -> (List[FewShotExample], List[List[FewShotExample]], int):
        """
        process raw_data into examples
        input:
            raw_data: sequence of episodes; each episode is a dict with
                'support' and 'query' parts (see ``batch2data_items``)
        output:
            examples: every (query, support set) pair, in order; gid is the
                position in this list
            few_shot_batches: the same examples grouped per episode
            max_support_size: largest support-set size (0 if raw_data is empty)
        """
        examples = []
        all_support_size = []
        few_shot_batches = []
        # Notice: the batch here means few shot batch, not training batch
        for batch_id, batch in enumerate(raw_data):
            support_data_items, test_data_items = self.batch2data_items(batch)
            all_support_size.append(len(support_data_items))
            one_batch_examples = []
            # Pair each test (query) sample with the full support set.
            for test_id, test_data_item in enumerate(test_data_items):
                example = FewShotExample(
                    gid=len(examples),  # running index over all batches
                    batch_id=batch_id,
                    test_id=test_id,
                    test_data_item=test_data_item,
                    support_data_items=support_data_items,
                )
                examples.append(example)
                one_batch_examples.append(example)
            few_shot_batches.append(one_batch_examples)
        # default=0 so an empty data file doesn't crash with
        # "max() arg is an empty sequence".
        max_support_size = max(all_support_size, default=0)
        return examples, few_shot_batches, max_support_size

    def batch2data_items(self, batch: dict) -> (List[DataItem], List[DataItem]):
        """Split one episode into its support and query (test) data items."""
        support_data_items = self.get_data_items(parts=batch['support'])
        test_data_items = self.get_data_items(parts=batch['query'])
        return support_data_items, test_data_items

    def get_data_items(self, parts: dict) -> List[DataItem]:
        """
        build one DataItem per position of the parallel lists in ``parts``
        input:
            parts: dict with parallel lists under 'seq_ins', 'seq_outs' and 'labels'
        output:
            list of DataItem, in the original order
        raises:
            ValueError: if the three lists differ in length (zip would otherwise
                silently drop the unmatched tail)
        """
        seq_ins, seq_outs, labels = parts['seq_ins'], parts['seq_outs'], parts['labels']
        if not len(seq_ins) == len(seq_outs) == len(labels):
            raise ValueError(
                'mismatched lengths in data part: seq_ins={}, seq_outs={}, labels={}'.format(
                    len(seq_ins), len(seq_outs), len(labels)))
        return [
            DataItem(seq_in=seq_in, seq_out=seq_out, label=label)
            for seq_in, seq_out, label in zip(seq_ins, seq_outs, labels)
        ]