ictnlp / DialoFlow

Code for ACL 2021 main conference paper "Conversations Are Not Flat: Modeling the Dynamic Information Flow across Dialogue Utterances".
MIT License

How to reproduce #13

Closed FlyingCat-fa closed 2 years ago

FlyingCat-fa commented 2 years ago

I reproduced the DialoFlow base on DailyDialog; the evaluation results are:

NIST: [2.9148, 3.3919, 3.5077, 3.5375]
BLEU: [0.4535, 0.2323, 0.1367, 0.086]
METEOR: 0.1479778034868275
Entropy: [6.250107671407306, 8.663223223839859, 9.603956363262926, 9.959120587252972]
Distinct: [0.08599954617653732, 0.32188216456202917]
avg_len: 9.154005934718102

These results are lower than those reported in the paper.

Could you share the details of fine-tuning on DailyDialog?

My training settings are: batch_size 16 (4 GPUs, per_gpu_batch_size=4), gradient_accumulation_steps 1, 50 epochs.

The best validation loss is 7.5168, reached at epoch 34.

Generation: the config parameters are the defaults, and I set beam_size=5.

I did not use Apex.

lizekang commented 2 years ago

My training setting is: batch size 64 (you can use gradient accumulation to reach it), 50 epochs.
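The relation between the per-GPU batch size and the effective batch size the optimizer sees can be sketched as follows (a hypothetical helper for illustration, not part of DialoFlow's code):

```python
def effective_batch_size(per_gpu_batch, n_gpus, accumulation_steps):
    # The optimizer effectively steps on per-GPU batch * #GPUs * accumulated steps.
    return per_gpu_batch * n_gpus * accumulation_steps

# With 4 GPUs at per_gpu_batch_size=4, accumulating 4 steps reaches batch 64:
print(effective_batch_size(4, 4, 4))  # -> 64
```

So the earlier setting (4 GPUs, per_gpu_batch_size=4, gradient_accumulation_steps=1) corresponds to an effective batch size of 16, not 64.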

You can select the checkpoint with the best response generation loss. In the paper, we select the model with about 2.30 validation loss.

Also, please check the format of your predictions against the references.

If you have any problems, please feel free to ask.

FlyingCat-fa commented 2 years ago

Thanks.

I trained the DialoFlow base again with batch size 64 for 50 epochs, and selected the checkpoint with the best response generation loss (2.29), obtained at epoch 50. I also used Apex.

The results are:

NIST-2: 3.2396
NIST-4: 3.3602
BLEU-2: 0.2197
BLEU-4: 0.0743
METEOR: 0.13976661322123995
Entropy-4: 9.867206665075539
Dist-1: 0.07665375927718619
Dist-2: 0.29163649529326574
avg_len: 9.195845697329377

These results are still lower than those reported in the paper.

I constructed the test file from multirefeval-master. For example, the test contexts built from the first two dialogues are (one context per line):

hey man , you wanna buy some weed ?
hey man , you wanna buy some weed ? EOS some what ?
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic !
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some !
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ?
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free . EOS sounds good ! let ' s see , i want .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free . EOS sounds good ! let ' s see , i want . EOS yeah ?
the taxi drivers are on strike again .
the taxi drivers are on strike again . EOS what for ?
the taxi drivers are on strike again . EOS what for ? EOS they want the government to reduce the price of the gasoline .
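Cumulative contexts like these can be built from a dialogue in a few lines. A minimal sketch, assuming utterances are joined by " EOS " (the helper name is hypothetical):

```python
def build_contexts(utterances):
    # For each response to be generated, the context is the concatenation
    # of all preceding turns, joined by " EOS ".
    return [" EOS ".join(utterances[:i]) for i in range(1, len(utterances))]

dialogue = ["hey man , you wanna buy some weed ?",
            "some what ?",
            "oh , umm , no thanks ."]
print(build_contexts(dialogue))
# -> ['hey man , you wanna buy some weed ?',
#     'hey man , you wanna buy some weed ? EOS some what ?']
```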

The predictions from DialoFlow's generate.py are:

sure, what do you want to buy? I'm not sure yet. you know, I've never smoked before, but I've been thinking about it. that sounds like a good idea. What do you think of it? what do you want to get? I can give you some. how much do you want me to pay for it? how much do you want to spend? I can give you a discount. that sounds like fun. What do you want to get? well, i got them from a friend of mine. really? that's great! i'll take it! do you want to try some of my new stuff? how much do you want to get for a gram? are you kidding? The taxi drivers are working overtime. the government is trying to clamp down on the competition. that's stupid. They should do something about it.

I also wrote the references from multirefeval-master to five files.
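Writing the multi-reference sets can be sketched like this (a hypothetical helper; it assumes refs[j] holds the alternative references for the j-th test example, and follows the ref_*.txt naming that the evaluation script globs for):

```python
import os

def write_reference_files(refs, out_dir="."):
    # refs[j][i] is the i-th alternative reference for the j-th example.
    # ref_0.txt collects every example's first reference, ref_1.txt the
    # second, and so on (one line per example, aligned with the hyp file).
    n_refs = len(refs[0])
    for i in range(n_refs):
        path = os.path.join(out_dir, "ref_%d.txt" % i)
        with open(path, "w", encoding="utf-8") as f:
            for example_refs in refs:
                f.write(example_refs[i] + "\n")
```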

Could you spot any errors?

lizekang commented 2 years ago

What does the reference file look like? Note that there is a space before the punctuation in the references, but no space in your prediction file.

FlyingCat-fa commented 2 years ago

Thanks.

I set clean_up_tokenization_spaces=False in GPT2Tokenizer.decode() and predicted again.

The predictions are: sure , what do you want to buy ? I'm not sure yet . you know , I've never smoked before , but I've been thinking about it . that sounds like a good idea . What do you think of it ? what do you want to get ? I can give you some . how much do you want me to pay for it ? how much do you want to spend ? I can give you a discount . that sounds like fun . What do you want to get ? well , i got them from a friend of mine . really ? that ' s great ! i ' ll take it ! do you want to try some of my new stuff ? how much do you want to get for a gram ? are you kidding ? The taxi drivers are working overtime . the government is trying to clamp down on the competition . that ' s stupid . They should do something about it .

Now there are spaces before punctuation in my prediction file, but some contractions still have no space, such as "I'm" and "I've" in lines 1 and 2. Meanwhile, in the line "really ? that ' s great ! i ' ll take it !", "that ' s" and "i ' ll" are split into 3 tokens each.
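For context, with clean_up_tokenization_spaces=True the decoder glues punctuation and contraction suffixes back onto the previous token, roughly like the sketch below (an approximation for illustration, not transformers' actual code). With the flag off, decoding keeps whatever spacing the model generated, which is why both "I'm" and "that ' s" can appear in the same output:

```python
def clean_up_tokenization(text):
    # Rough approximation (hypothetical) of the cleanup that the
    # clean_up_tokenization_spaces=True decode option applies.
    for before, after in [(" .", "."), (" ,", ","), (" !", "!"), (" ?", "?"),
                          (" ' ", "'"), (" n't", "n't"), (" 'm", "'m"),
                          (" 's", "'s"), (" 've", "'ve"), (" 're", "'re")]:
        text = text.replace(before, after)
    return text

print(clean_up_tokenization("that ' s great ! i ' ll take it !"))
# -> "that's great! i'll take it!"
```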

Could you tell me how you generate, or how to add the spaces?

I ran generate.py for generation, with transformers==3.0.2.

The evaluation results are:

NIST-2: 3.1878
NIST-4: 3.304
BLEU-2: 0.2164
BLEU-4: 0.0721
METEOR: 0.14717688674636875
Entropy-4: 9.872796155919488
Dist-1: 0.037200750563552334
Dist-2: 0.1993201866046486
avg_len: 11.781454005934718

The references in ref_0.txt are (one per line):

some what ?
weed ! you know ? pot , ganja , mary jane some chronic !
oh , umm , no thanks .
i also have blow if you prefer to do a few lines .
no , i am ok , really .
come on man ! i even got dope and acid ! try some !
do you really have all of these drugs ? where do you get them from ?
i got my connections ! just tell me what you want and i 'll even give you one ounce for free .
sounds good ! let 's see , i want .
yeah ?
i want you to put your hands behind your head ! you are under arrest !
what for ?
they want the government to reduce the price of the gasoline .
it is really a hot potato .

lizekang commented 2 years ago

You can use nltk.word_tokenize() to tokenize the sentence and then concatenate the tokens.
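That retokenization step can be sketched as below. A small regex tokenizer stands in here for nltk.word_tokenize (which additionally needs the punkt data downloaded), but the idea is the same: re-split the text into tokens and rejoin them with single spaces:

```python
import re

def simple_word_tokenize(text):
    # Rough stand-in for nltk.word_tokenize: peel off contraction
    # suffixes and punctuation, so "that's" -> ["that", "'s"].
    return re.findall(r"\w+|'\w+|[^\w\s]", text)

def retokenize(line):
    return " ".join(simple_word_tokenize(line))

print(retokenize("that's great! i'll take it!"))
# -> "that 's great ! i 'll take it !"
```

This normalizes the predictions to the same spacing convention as the references (e.g. "i 'll" rather than "i'll" or "i ' ll").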

lizekang commented 2 years ago

What are your evaluation scripts?

FlyingCat-fa commented 2 years ago

Thanks.

I will try the nltk tokenizer.

My evaluation script is almost the same as the DSTC eval in DialoGPT:


import argparse
import io
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np

# Python 2/3 compatibility shims.
py_version = sys.version.split('.')[0]
if py_version == '2':
    open = io.open
else:
    unicode = str

def makedirs(fld):
    if not os.path.exists(fld):
        os.makedirs(fld)

cur_dir = str(Path(__file__).parent)

def str2bool(s):
    # to avoid issue like this: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    if s.lower() in ['t','true','1','y']:
        return True
    elif s.lower() in ['f','false','0','n']:
        return False
    else:
        raise ValueError

def calc_nist_bleu(path_refs, path_hyp, fld_out='temp', n_lines=None):
    # call mteval-v14c.pl
    # ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl
    # you may need to `cpan install XML::Twig Sort::Naturally String::Util`

    makedirs(fld_out)

    if n_lines is None:
        n_lines = len(open(path_refs[0], encoding='utf-8').readlines())
    _write_xml([''], fld_out + '/src.xml', 'src', n_lines=n_lines)
    _write_xml([path_hyp], fld_out + '/hyp.xml', 'hyp')
    _write_xml(path_refs, fld_out + '/ref.xml', 'ref')

    time.sleep(1)
    cmd = [
        'perl',f'{cur_dir}/3rdparty/mteval-v14c-20190801/mteval-v14c.pl',
        '-s', '%s/src.xml'%fld_out,
        '-t', '%s/hyp.xml'%fld_out,
        '-r', '%s/ref.xml'%fld_out,
        ]
    # capture stderr too, otherwise error.decode() below fails on None
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()

    lines = output.decode().split('\n')

    try:
        nist = lines[-6].strip('\r').split()[1:5]
        bleu = lines[-4].strip('\r').split()[1:5]
        return [float(x) for x in nist], [float(x) for x in bleu]

    except Exception:
        print('mteval-v14c.pl returns unexpected message')
        print('cmd = '+str(cmd))
        print(output.decode())
        print(error.decode())
        return [-1]*4, [-1]*4

def calc_cum_bleu(path_refs, path_hyp):
    # call multi-bleu.pl
    # https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl
    # the 4-gram cum BLEU returned by this one should be very close to calc_nist_bleu
    # however multi-bleu.pl doesn't return cum BLEU of lower rank, so in nlp_metrics we prefer calc_nist_bleu
    # NOTE: this func doesn't support the n_lines argument and its output is not parsed yet

    process = subprocess.Popen(
            ['perl', f'{cur_dir}/multi-bleu.perl'] + path_refs, 
            stdout=subprocess.PIPE, 
            stdin=subprocess.PIPE
            )
    with open(path_hyp, encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        process.stdin.write(line.encode())
    output, error = process.communicate()
    return output.decode()

def calc_meteor(path_refs, path_hyp, fld_out='temp', n_lines=None, pretokenized=True):
    # Call METEOR code.
    # http://www.cs.cmu.edu/~alavie/METEOR/index.html

    makedirs(fld_out)
    path_merged_refs = fld_out + '/refs_merged.txt'
    _write_merged_refs(path_refs, path_merged_refs)
    cmd = [
            'java', '-Xmx1g',    # heapsize of 1G to avoid OutOfMemoryError
            '-jar', f'{cur_dir}/3rdparty/meteor-1.5/meteor-1.5.jar', 
            path_hyp, path_merged_refs, 
            '-r', '%i'%len(path_refs),     # refCount 
            '-l', 'en', '-norm'     # also supports language: cz de es fr ar
            ]
    # print(cmd)
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    for line in output.decode().split('\n'):
        if "Final score:" in line:
            return float(line.split()[-1])

    print('meteor-1.5.jar returns unexpected message')
    print("cmd = " + " ".join(cmd))
    print(output.decode())
    print(error.decode())
    return -1 

def calc_entropy(path_hyp, n_lines=None):
    # based on Yizhe Zhang's code
    etp_score = [0.0,0.0,0.0,0.0]
    counter = [defaultdict(int),defaultdict(int),defaultdict(int),defaultdict(int)]
    i = 0
    for line in open(path_hyp, encoding='utf-8'):
        i += 1
        words = line.strip('\n').split()
        for n in range(4):
            for idx in range(len(words)-n):
                ngram = ' '.join(words[idx:idx+n+1])
                counter[n][ngram] += 1
        if i == n_lines:
            break

    for n in range(4):
        total = sum(counter[n].values())
        for v in counter[n].values():
            etp_score[n] += -v / total * (np.log(v) - np.log(total))

    return etp_score

def calc_len(path, n_lines):
    l = []
    for line in open(path, encoding='utf8'):
        l.append(len(line.strip('\n').split()))
        if len(l) == n_lines:
            break
    return np.mean(l)

def calc_diversity(path_hyp):
    tokens = [0.0,0.0]
    types = [defaultdict(int),defaultdict(int)]
    for line in open(path_hyp, encoding='utf-8'):
        words = line.strip('\n').split()
        for n in range(2):
            for idx in range(len(words)-n):
                ngram = ' '.join(words[idx:idx+n+1])
                types[n][ngram] = 1
                tokens[n] += 1
    div1 = len(types[0].keys())/tokens[0]
    div2 = len(types[1].keys())/tokens[1]
    return [div1, div2]

def nlp_metrics(path_refs, path_hyp, fld_out='temp',  n_lines=None):
    nist, bleu = calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines)
    meteor = calc_meteor(path_refs, path_hyp, fld_out, n_lines)
    entropy = calc_entropy(path_hyp, n_lines)
    div = calc_diversity(path_hyp)
    avg_len = calc_len(path_hyp, n_lines)
    return nist, bleu, meteor, entropy, div, avg_len

def _write_merged_refs(paths_in, path_out, n_lines=None):
    # prepare merged ref file for meteor-1.5.jar (calc_meteor)
    # lines[i][j] is the ref from i-th ref set for the j-th query

    lines = []
    for path_in in paths_in:
        lines.append([line.strip('\n') for line in open(path_in, encoding='utf-8')])

    with open(path_out, 'w', encoding='utf-8') as f:
        for j in range(len(lines[0])):
            for i in range(len(paths_in)):
                f.write(unicode(lines[i][j]) + "\n")

def _write_xml(paths_in, path_out, role, n_lines=None):
    # prepare .xml files for mteval-v14c.pl (calc_nist_bleu)
    # role = 'src', 'hyp' or 'ref'

    lines = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<!DOCTYPE mteval SYSTEM "">',
        '<!-- generated by https://github.com/golsun/NLP-tools -->',
        '<!-- from: %s -->'%paths_in,
        '<!-- as inputs for ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl -->',
        '<mteval>',
        ]

    for i_in, path_in in enumerate(paths_in):

        # header ----

        if role == 'src':
            lines.append('<srcset setid="unnamed" srclang="src">')
            set_ending = '</srcset>'
        elif role == 'hyp':
            lines.append('<tstset setid="unnamed" srclang="src" trglang="tgt" sysid="unnamed">')
            set_ending = '</tstset>'
        elif role == 'ref':
            lines.append('<refset setid="unnamed" srclang="src" trglang="tgt" refid="ref%i">'%i_in)
            set_ending = '</refset>'

        lines.append('<doc docid="unnamed" genre="unnamed">')

        # body -----

        if role == 'src':
            body = ['__src__'] * n_lines
        else:
            with open(path_in, 'r', encoding='utf-8') as f:
                body = f.readlines()
            if n_lines is not None:
                body = body[:n_lines]
        for i, b in enumerate(body):
            line = b.strip('\n')
            line = line.replace('&', ' ').replace('<', ' ')  # strip chars illegal in XML
            lines.append('<p><seg id="%i"> %s </seg></p>' % (i + 1, line))

        # ending -----

        lines.append('</doc>')
        if role == 'src':
            lines.append('</srcset>')
        elif role == 'hyp':
            lines.append('</tstset>')
        elif role == 'ref':
            lines.append('</refset>')

    lines.append('</mteval>')
    with open(path_out, 'w', encoding='utf-8') as f:
        f.write(unicode('\n'.join(lines)))

def dialogue_evaluation(hyp_file, ref_file, fld_out):
    nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(ref_file, hyp_file, fld_out)
    results = {
        'NIST-2': nist[1],
        'NIST-4': nist[3],
        'BLEU-2': bleu[1],
        'BLEU-4': bleu[3],
        'METEOR': meteor,
        'Entropy-4': entropy[3],
        'Dist-1': div[0],
        'Dist-2': div[1],
        'avg_len': avg_len
    }
    return results

def write_metrics(metric_results, output_path):
    with open(output_path, 'w') as output:
        output.write('filename ')
        output.write(output_path + '\n')

        for metric_name, metric in metric_results.items():
            # Write the metric to file.
            m = metric_name + ':    ' + str(metric)
            output.write(m + '\n')
        output.write('\n')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--refs_dir', default='multirefs')
    parser.add_argument('--ref_file', default=None)
    parser.add_argument('--hyp_file', default='DialoFlow_DialyDialog_generated.txt', required=False)
    parser.add_argument('--fld_out',default='temp_DialoFlow', required=False)

    args = parser.parse_args()
    if args.ref_file is not None:
        refs_files = [args.ref_file]
    else:
        refs_files = list(map(str, Path(args.refs_dir).glob('ref_*.txt')))
    print("references: ", refs_files)

    metric_results = dialogue_evaluation(args.hyp_file, refs_files, args.fld_out)
    write_metrics(metric_results=metric_results, output_path=os.path.join(args.fld_out, 'metric_results_v3.txt'))

if __name__ == "__main__":
    main()

FlyingCat-fa commented 2 years ago

Thanks a lot.

I solved the problem by the nltk tokenizer.