lipiji / SongNet

Code for ACL 2020 paper "Rigid Formats Controlled Text Generation": https://www.aclweb.org/anthology/2020.acl-main.68/
MIT License
230 stars 40 forks

Does anyone have TensorFlow code for top-k sampling? #7

Closed guotong1988 closed 3 years ago

guotong1988 commented 3 years ago

@lipiji Many thanks!

guotong1988 commented 3 years ago

https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/beam_search.py

guotong1988 commented 3 years ago

My own implementation of top-k sampling seems to score much lower on BLEU than beam search.

guotong1988 commented 3 years ago

My top-k sampling is roughly as follows:

import numpy as np

def softmax(x):
    # Subtract the row-wise max before exponentiating for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    x_exp = np.exp(x)
    return x_exp / x_exp.sum(axis=-1, keepdims=True)

logits_token = np.array([0.1, 0.1, 9.9, 0.1, 0.1, 0.1])
# Indices of the 3 largest logits (argsort is ascending, so take the last 3).
top3index = np.argsort(logits_token)[-3:]
top3value = logits_token[top3index]
# Renormalize over the top-3 logits only.
top3prob = softmax(top3value)
print(top3value)
print(top3index)
print(top3prob)
# Sample one token id from the top-3 candidates according to top3prob.
top1index = np.random.choice(top3index, 1, p=top3prob)
print(top1index)