shannanyinxiang / SPTS

Official implementation of SPTS: Single-Point Text Spotting (ACM MM 2022 Oral)

Inference after training without coordinates? #12

Closed yangxcccscsa closed 1 year ago

yangxcccscsa commented 1 year ago

Hello. I want to modify the decode_pred_seq function in engine/val.py so that it decodes only the text. How exactly should it be changed?

shannanyinxiang commented 1 year ago

Hello, you can try the following modification:

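def decode_pred_seq(index_seq, prob_seq, target, args):
    # Drop the last predicted token from both sequences.
    index_seq = index_seq[:-1]
    prob_seq = prob_seq[:-1]

    # Each text instance is 25 tokens; drop any trailing partial instance.
    # Compute the remainder once: -len(x) % 25 parses as (-len(x)) % 25.
    remainder = len(index_seq) % 25
    if remainder != 0:
        index_seq = index_seq[:-remainder]
        prob_seq = prob_seq[:-remainder]

    decode_results = decode_seq(index_seq, 'none', args)
    # One confidence per instance: the mean probability of its 25 tokens.
    confs = prob_seq.reshape(-1, 25).mean(-1)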

    image_id = target['image_id']
    results = []
    for decode_result, conf in zip(decode_results, confs):
        recog = decode_result['recog']
        result = {
            'image_id': image_id,
            'rec': recog,
            'score': conf.item()
        }
        results.append(result)

    return results

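def decode_seq(seq, type, args):
    # Remove padding tokens before grouping.
    seq = seq[seq != args.padding_index]
    if type == 'input':
        seq = seq[1:]   # drop the first token
    elif type == 'output':
        seq = seq[:-1]  # drop the last token
    elif type != 'none':
        raise ValueError(f"unknown sequence type: {type!r}")
    # One row per text instance, 25 recognition tokens each.
    seq = seq.reshape(-1, 25)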

    decode_result = []
    for text_ins_seq in seq:
        recog = []
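        for index in text_ins_seq:
            if index == args.recog_pad_index:
                break  # padding marks the end of this instance's text
            if index == args.recog_pad_index - 1:
                continue  # skip this token without emitting a character
            # Character tokens are offset by num_bins in the vocabulary.
            recog.append(args.chars[index - args.num_bins])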
        recog = ''.join(recog)
        decode_result.append(
            {'recog': recog}
        )
    return decode_result
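
A note on the truncation step in decode_pred_seq, which trims the sequence to a multiple of 25 tokens: in Python, unary minus binds tighter than `%`, so an expression such as `seq[:-len(seq) % 25]` is evaluated as `seq[:(-len(seq)) % 25]` and does not drop just the leftover tokens. The sketch below illustrates this on a plain list. The helper name `trim_to_multiple` is hypothetical and not part of the SPTS code base.

```python
def trim_to_multiple(seq, group=25):
    """Drop trailing items so that len(seq) is a multiple of `group`.

    Hypothetical helper for illustration; not part of SPTS.
    """
    remainder = len(seq) % group
    return seq[:len(seq) - remainder]


# 27 tokens: one full group of 25 plus 2 leftover tokens.
tokens = list(range(27))

# Unary minus binds tighter than %, so the stop index is (-27) % 25 == 23,
# not -2. The result is not a multiple of 25.
naive = tokens[:-len(tokens) % 25]

# Computing the remainder first keeps exactly one full group.
trimmed = trim_to_multiple(tokens)

print(len(naive))    # 23
print(len(trimmed))  # 25
```

The same slicing works on a 1-D tensor, and because the stop index is `len(seq) - remainder`, a sequence whose length is already a multiple of 25 is returned unchanged.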