FlyingCat-fa closed this issue 2 years ago
My setting is: Training: batch size 64 (you can use gradient accumulation), 50 epochs.
You can select the checkpoint with the best response generation loss. In the paper, we selected the model with a validation loss of about 2.30.
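To make the selection step concrete, here is a minimal sketch of picking the checkpoint with the lowest validation loss. `select_best_checkpoint` is a hypothetical helper, not part of the DialoFlow codebase:

```python
# Hypothetical sketch: pick the checkpoint (epoch) with the lowest
# validation response-generation loss. Not from the DialoFlow codebase.
def select_best_checkpoint(val_losses):
    """val_losses: dict mapping epoch -> validation loss."""
    best_epoch = min(val_losses, key=val_losses.get)
    return best_epoch, val_losses[best_epoch]

# e.g. validation losses logged per epoch
losses = {48: 2.41, 49: 2.35, 50: 2.30}
epoch, loss = select_best_checkpoint(losses)
print(epoch, loss)  # 50 2.3
```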
You can also check the format of your predictions against the references.
If you have any problems, please feel free to ask.
Thanks.
I trained the DialoFlow base again with batch size = 64 and 50 epochs, and selected the checkpoint with the best response generation loss (2.29), obtained at epoch 50. I also used Apex.
The results are: NIST-2: 3.2396, NIST-4: 3.3602, BLEU-2: 0.2197, BLEU-4: 0.0743, METEOR: 0.13976661322123995, Entropy-4: 9.867206665075539, Dist-1: 0.07665375927718619, Dist-2: 0.29163649529326574, avg_len: 9.195845697329377
These results are lower than the results reported in the paper.
I constructed the test file from multirefeval-master. For example, the test contexts constructed from the first two dialogues are (one context per line):
hey man , you wanna buy some weed ?
hey man , you wanna buy some weed ? EOS some what ?
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic !
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some !
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ?
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free . EOS sounds good ! let ' s see , i want .
hey man , you wanna buy some weed ? EOS some what ? EOS weed ! you know ? pot , ganja , mary jane some chronic ! EOS oh , umm , no thanks . EOS i also have blow if you prefer to do a few lines . EOS no , i am ok , really . EOS come on man ! i even got dope and acid ! try some ! EOS do you really have all of these drugs ? where do you get them from ? EOS i got my connections ! just tell me what you want and i ' ll even give you one ounce for free . EOS sounds good ! let ' s see , i want . EOS yeah ?
the taxi drivers are on strike again .
the taxi drivers are on strike again . EOS what for ?
the taxi drivers are on strike again . EOS what for ? EOS they want the government to reduce the price of the gasoline .
The predictions from the DialoFlow generate.py are:
sure, what do you want to buy? I'm not sure yet. you know, I've never smoked before, but I've been thinking about it. that sounds like a good idea. What do you think of it? what do you want to get? I can give you some. how much do you want me to pay for it? how much do you want to spend? I can give you a discount. that sounds like fun. What do you want to get? well, i got them from a friend of mine. really? that's great! i'll take it! do you want to try some of my new stuff? how much do you want to get for a gram? are you kidding? The taxi drivers are working overtime. the government is trying to clamp down on the competition. that's stupid. They should do something about it.
I also wrote the references from multirefeval-master to five files.
Can you see any errors?
What about the reference file? You can see that there is a space before the punctuation in the reference, but there is no space in your prediction file.
Thanks.
I set clean_up_tokenization_spaces=False in GPT2Tokenizer.decode() and predicted again.
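For context, the cleanup that `clean_up_tokenization_spaces=True` (the default) applies is just a handful of string replacements, which is why the spaces before punctuation disappeared. The sketch below mirrors transformers' `clean_up_tokenization` as I understand it from the 3.x series; treat it as an approximation, not the library's definitive code:

```python
def clean_up_tokenization(out_string):
    # Approximation of transformers' clean_up_tokenization (3.x series):
    # removes spaces before punctuation and re-attaches common contractions.
    return (
        out_string.replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
    )

print(clean_up_tokenization("that ' s great ! i ' ll take it !"))
# that's great! i'll take it!
```

This explains the earlier output: with the cleanup on, the decoded text no longer matches the space-before-punctuation format of the reference files.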
The predictions are: sure , what do you want to buy ? I'm not sure yet . you know , I've never smoked before , but I've been thinking about it . that sounds like a good idea . What do you think of it ? what do you want to get ? I can give you some . how much do you want me to pay for it ? how much do you want to spend ? I can give you a discount . that sounds like fun . What do you want to get ? well , i got them from a friend of mine . really ? that ' s great ! i ' ll take it ! do you want to try some of my new stuff ? how much do you want to get for a gram ? are you kidding ? The taxi drivers are working overtime . the government is trying to clamp down on the competition . that ' s stupid . They should do something about it .
Now my prediction file also has the spaces, but I found that some abbreviations still have no space, such as "I'm" and "I've" in lines 1 and 2. Yet in the line "really ? that ' s great ! i ' ll take it !", the "that ' s" and "i ' ll" are split into 3 tokens.
Could you tell me how you generate, or how to add the spaces?
I run generate.py for generation, with transformers==3.0.2.
The evaluation results are: NIST-2: 3.1878, NIST-4: 3.304, BLEU-2: 0.2164, BLEU-4: 0.0721, METEOR: 0.14717688674636875, Entropy-4: 9.872796155919488, Dist-1: 0.037200750563552334, Dist-2: 0.1993201866046486, avg_len: 11.781454005934718
The references in ref_0.txt are (one per line):
some what ?
weed ! you know ? pot , ganja , mary jane some chronic !
oh , umm , no thanks .
i also have blow if you prefer to do a few lines .
no , i am ok , really .
come on man ! i even got dope and acid ! try some !
do you really have all of these drugs ? where do you get them from ?
i got my connections ! just tell me what you want and i 'll even give you one ounce for free .
sounds good ! let 's see , i want .
yeah ?
i want you to put your hands behind your head ! you are under arrest !
what for ?
they want the government to reduce the price of the gasoline .
it is really a hot potato .
You can use nltk.word_tokenize() to tokenize the sentence and then concatenate the tokens.
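The suggestion above can be sketched without the NLTK dependency. The regex version below is a hypothetical stand-in that only approximates `nltk.word_tokenize` for punctuation and simple contractions; with NLTK installed, `" ".join(nltk.word_tokenize(s))` is the direct equivalent:

```python
import re

def retokenize(text):
    # Minimal stand-in for nltk.word_tokenize on this data: put spaces
    # around punctuation and split contractions the way the reference
    # files do ("that's" -> "that 's", "i'll" -> "i 'll").
    text = re.sub(r"([.,!?;:])", r" \1 ", text)
    text = re.sub(r"(\w)'(\w)", r"\1 '\2", text)
    return " ".join(text.split())

print(retokenize("really? that's great! i'll take it!"))
# really ? that 's great ! i 'll take it !
```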
What is your evaluation script?
Thanks.
I will try the nltk tokenizer.
My evaluation script is almost the same as the DSTC eval in DialoGPT:
import os, time, subprocess, io, sys, re, argparse
from collections import defaultdict
from pathlib import Path
import numpy as np

py_version = sys.version.split('.')[0]
if py_version == '2':
    open = io.open
else:
    unicode = str

def makedirs(fld):
    if not os.path.exists(fld):
        os.makedirs(fld)

cur_dir = str(Path(__file__).parent)

def str2bool(s):
    # to avoid issues like this: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    if s.lower() in ['t', 'true', '1', 'y']:
        return True
    elif s.lower() in ['f', 'false', '0', 'n']:
        return False
    else:
        raise ValueError

def calc_nist_bleu(path_refs, path_hyp, fld_out='temp', n_lines=None):
    # call mteval-v14c.pl
    # ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl
    # you may need to cpan install XML::Twig, Sort::Naturally, String::Util
    makedirs(fld_out)
    if n_lines is None:
        n_lines = len(open(path_refs[0], encoding='utf-8').readlines())
    _write_xml([''], fld_out + '/src.xml', 'src', n_lines=n_lines)
    _write_xml([path_hyp], fld_out + '/hyp.xml', 'hyp')
    _write_xml(path_refs, fld_out + '/ref.xml', 'ref')
    time.sleep(1)
    cmd = [
        'perl', f'{cur_dir}/3rdparty/mteval-v14c-20190801/mteval-v14c.pl',
        '-s', '%s/src.xml' % fld_out,
        '-t', '%s/hyp.xml' % fld_out,
        '-r', '%s/ref.xml' % fld_out,
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    lines = output.decode().split('\n')
    try:
        nist = lines[-6].strip('\r').split()[1:5]
        bleu = lines[-4].strip('\r').split()[1:5]
        return [float(x) for x in nist], [float(x) for x in bleu]
    except Exception:
        print('mteval-v14c.pl returned an unexpected message')
        print('cmd = ' + str(cmd))
        print(output.decode())
        print(error.decode())
        return [-1] * 4, [-1] * 4

def calc_cum_bleu(path_refs, path_hyp):
    # call multi-bleu.perl
    # https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/multi-bleu.perl
    # the 4-gram cumulative BLEU returned by this one should be very close to calc_nist_bleu;
    # however, multi-bleu.perl doesn't return cumulative BLEU of lower order, so nlp_metrics prefers calc_nist_bleu
    # NOTE: this function doesn't support the n_lines argument and its output is not parsed yet
    process = subprocess.Popen(
        ['perl', f'{cur_dir}/multi-bleu.perl'] + path_refs,
        stdout=subprocess.PIPE,
        stdin=subprocess.PIPE
    )
    with open(path_hyp, encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        process.stdin.write(line.encode())
    output, error = process.communicate()
    return output.decode()

def calc_meteor(path_refs, path_hyp, fld_out='temp', n_lines=None, pretokenized=True):
    # call METEOR
    # http://www.cs.cmu.edu/~alavie/METEOR/index.html
    makedirs(fld_out)
    path_merged_refs = fld_out + '/refs_merged.txt'
    _write_merged_refs(path_refs, path_merged_refs)
    cmd = [
        'java', '-Xmx1g',  # heap size of 1G to avoid OutOfMemoryError
        '-jar', f'{cur_dir}/3rdparty/meteor-1.5/meteor-1.5.jar',
        path_hyp, path_merged_refs,
        '-r', '%i' % len(path_refs),  # refCount
        '-l', 'en', '-norm'  # also supports: cz de es fr ar
    ]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, error = process.communicate()
    for line in output.decode().split('\n'):
        if "Final score:" in line:
            return float(line.split()[-1])
    print('meteor-1.5.jar returned an unexpected message')
    print("cmd = " + " ".join(cmd))
    print(output.decode())
    print(error.decode())
    return -1

def calc_entropy(path_hyp, n_lines=None):
    # based on Yizhe Zhang's code
    etp_score = [0.0, 0.0, 0.0, 0.0]
    counter = [defaultdict(int), defaultdict(int), defaultdict(int), defaultdict(int)]
    i = 0
    for line in open(path_hyp, encoding='utf-8'):
        i += 1
        words = line.strip('\n').split()
        for n in range(4):
            for idx in range(len(words) - n):
                ngram = ' '.join(words[idx:idx + n + 1])
                counter[n][ngram] += 1
        if i == n_lines:
            break
    for n in range(4):
        total = sum(counter[n].values())
        for v in counter[n].values():
            etp_score[n] += -v / total * (np.log(v) - np.log(total))
    return etp_score

def calc_len(path, n_lines):
    l = []
    for line in open(path, encoding='utf-8'):
        l.append(len(line.strip('\n').split()))
        if len(l) == n_lines:
            break
    return np.mean(l)

def calc_diversity(path_hyp):
    tokens = [0.0, 0.0]
    types = [defaultdict(int), defaultdict(int)]
    for line in open(path_hyp, encoding='utf-8'):
        words = line.strip('\n').split()
        for n in range(2):
            for idx in range(len(words) - n):
                ngram = ' '.join(words[idx:idx + n + 1])
                types[n][ngram] = 1
                tokens[n] += 1
    div1 = len(types[0].keys()) / tokens[0]
    div2 = len(types[1].keys()) / tokens[1]
    return [div1, div2]

def nlp_metrics(path_refs, path_hyp, fld_out='temp', n_lines=None):
    nist, bleu = calc_nist_bleu(path_refs, path_hyp, fld_out, n_lines)
    meteor = calc_meteor(path_refs, path_hyp, fld_out, n_lines)
    entropy = calc_entropy(path_hyp, n_lines)
    div = calc_diversity(path_hyp)
    avg_len = calc_len(path_hyp, n_lines)
    return nist, bleu, meteor, entropy, div, avg_len

def _write_merged_refs(paths_in, path_out, n_lines=None):
    # prepare the merged ref file for meteor-1.5.jar (calc_meteor)
    # lines[i][j] is the ref from the i-th ref set for the j-th query
    lines = []
    for path_in in paths_in:
        lines.append([line.strip('\n') for line in open(path_in, encoding='utf-8')])
    with open(path_out, 'w', encoding='utf-8') as f:
        for j in range(len(lines[0])):
            for i in range(len(paths_in)):
                f.write(unicode(lines[i][j]) + "\n")

def _write_xml(paths_in, path_out, role, n_lines=None):
    # prepare .xml files for mteval-v14c.pl (calc_nist_bleu)
    # role = 'src', 'hyp' or 'ref'
    lines = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<!DOCTYPE mteval SYSTEM "">',
        '<!-- generated by https://github.com/golsun/NLP-tools -->',
        '<!-- from: %s -->' % paths_in,
        '<!-- as inputs for ftp://jaguar.ncsl.nist.gov/mt/resources/mteval-v14c.pl -->',
        '<mteval>',
    ]
    for i_in, path_in in enumerate(paths_in):
        # header ----
        if role == 'src':
            lines.append('<srcset setid="unnamed" srclang="src">')
            set_ending = '</srcset>'
        elif role == 'hyp':
            lines.append('<tstset setid="unnamed" srclang="src" trglang="tgt" sysid="unnamed">')
            set_ending = '</tstset>'
        elif role == 'ref':
            lines.append('<refset setid="unnamed" srclang="src" trglang="tgt" refid="ref%i">' % i_in)
            set_ending = '</refset>'
        lines.append('<doc docid="unnamed" genre="unnamed">')
        # body -----
        if role == 'src':
            body = ['__src__'] * n_lines
        else:
            with open(path_in, 'r', encoding='utf-8') as f:
                body = f.readlines()
        if n_lines is not None:
            body = body[:n_lines]
        i = 0
        for b in body:
            line = b.strip('\n')
            line = line.replace('&', ' ').replace('<', ' ')  # remove illegal xml chars
            lines.append('<p><seg id="%i"> %s </seg></p>' % (i + 1, line))
            i += 1
        # ending -----
        lines.append('</doc>')
        lines.append(set_ending)
    lines.append('</mteval>')
    with open(path_out, 'w', encoding='utf-8') as f:
        f.write(unicode('\n'.join(lines)))

def dialogue_evaluation(hyp_file, ref_file, fld_out):
    nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(ref_file, hyp_file, fld_out)
    results = {
        'NIST-2': nist[1],
        'NIST-4': nist[3],
        'BLEU-2': bleu[1],
        'BLEU-4': bleu[3],
        'METEOR': meteor,
        'Entropy-4': entropy[3],
        'Dist-1': div[0],
        'Dist-2': div[1],
        'avg_len': avg_len
    }
    return results

def write_metrics(metric_results, output_path):
    with open(output_path, 'w') as output:
        output.write('filename ')
        output.write(output_path + '\n')
        for metric_name, metric in metric_results.items():
            # write each metric to file
            m = metric_name + ': ' + str(metric)
            output.write(m + '\n')
        output.write('\n')

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--refs_dir', default='multirefs')
    parser.add_argument('--ref_file', default=None)
    parser.add_argument('--hyp_file', default='DialoFlow_DialyDialog_generated.txt', required=False)
    parser.add_argument('--fld_out', default='temp_DialoFlow', required=False)
    args = parser.parse_args()
    if args.ref_file is not None:
        refs_files = [args.ref_file]
    else:
        refs_files = list(map(str, Path(args.refs_dir).glob('ref_*.txt')))
    print("references: ", refs_files)
    metric_results = dialogue_evaluation(args.hyp_file, refs_files, args.fld_out)
    write_metrics(metric_results=metric_results, output_path=os.path.join(args.fld_out, 'metric_results_v3.txt'))

if __name__ == "__main__":
    main()
Thanks a lot.
I solved the problem with the nltk tokenizer.
I reproduced the DialoFlow base on DailyDialog; the evaluation results are: NIST: [2.9148, 3.3919, 3.5077, 3.5375], BLEU: [0.4535, 0.2323, 0.1367, 0.086], METEOR: 0.1479778034868275, Entropy: [6.250107671407306, 8.663223223839859, 9.603956363262926, 9.959120587252972], Distinct: [0.08599954617653732, 0.32188216456202917], avg_len: 9.154005934718102
The results are still lower than those reported in the paper.
Can you share the details of fine-tuning on DailyDialog?
My setting is: Training: batch size 16 (4 GPUs, per_gpu_batch_size=4), gradient_accumulation_steps=1, 50 epochs.
The best validation loss is 7.5168, at epoch 34.
Generation: the config parameters are the defaults, and I set beam_size=5.
I did not use Apex.