Closed by ktrapeznikov 3 years ago
Thanks for pointing that out!
GeDi's heuristics for combining the generative classifier outputs with the LM logits were designed with greedy decoding in mind, so the experiments in our paper were all done with greedy decoding (Section 3.1.1); hence we don't support sampling as of now. Ensuring that sequences generated with sampling carry the desired attribute (e.g. positive sentiment) will likely need some tuning of the decoding hyperparameters (\omega, \rho, n) and some algorithmic changes.
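For background, the combination being discussed is (roughly) a weighted Bayes-rule product of the base LM's next-token distribution and the generative classifier's class posterior. Here is a minimal pure-Python sketch of just that step, assuming the simplest form of the heuristic; the repo's actual code also applies the \rho and n filters, and the function name here is illustrative:

```python
import math

def gedi_combine(lm_logprobs, class_logprobs, omega=30.0):
    """Combine LM log-probs with per-token class posteriors.

    lm_logprobs: log p(x_t | x_<t) from the base LM, one entry per token.
    class_logprobs: log p(desired class | x_<t, x_t), one entry per token,
        as estimated by the GeDi generative classifier.
    omega: class-conditioning weight (the disc_weight knob).

    Returns renormalized log-probs proportional to
        p(x_t | x_<t) * p(class | x_<t, x_t)^omega.
    """
    scores = [lp + omega * cp for lp, cp in zip(lm_logprobs, class_logprobs)]
    # log-sum-exp renormalization back to a proper distribution
    m = max(scores)
    z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - z for s in scores]
```

Greedy decoding then takes the argmax of these scores; sampling would draw from them instead.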
Were you able to generate reasonable sequences with sampling using GeDi-guided generation?
Yeah, the output looks reasonable. It works pretty well for "topic" mode: I set temperature = 1 and increased disc_weight to 50. The class probability estimate is usually pretty high (in the 90s). It is slightly worse for "sentiment" mode.
Looking at the generate code, sampling is applied to next_logit_prob after it has been modulated by the GeDi probabilities.
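In other words, once the GeDi-modulated next-token distribution is formed, do_sample draws from it instead of taking the argmax. A minimal sketch of that final step, in plain Python rather than the repo's batched torch code (the function name is illustrative):

```python
import math
import random

def sample_from_logprobs(logprobs, temperature=1.0, rng=random):
    """Sample a token index from GeDi-modulated log-probabilities.

    Greedy decoding would return the argmax; with sampling enabled the
    same modulated distribution is drawn from instead (optionally after
    top-k / top-p filtering of the underlying logits).
    """
    scaled = [lp / temperature for lp in logprobs]
    # softmax with a max-shift for numerical stability
    m = max(scaled)
    probs = [math.exp(s - m) for s in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    # inverse-CDF sampling
    r = rng.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1
```

Raising disc_weight sharpens the modulated distribution toward on-class tokens, which is why sampling can still land in the desired class most of the time.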
@ktrapeznikov Do you want to create a pull request with this addition to the modeling_utils.py
file? Or I could make the changes in a commit.
top_k_top_p_filtering added!
Awesome. Thanks. Somehow I missed your previous comment.
ImportError: cannot import name 'top_k_top_p_filtering' from 'transformers.generation_utils'
So, in the previous code, top_k_top_p_filtering was imported from transformers.generation_utils. I changed the underscore to a dot, as shown below, and it solved my problem:
from transformers.generation.utils import top_k_top_p_filtering
Alternatively, you can downgrade transformers and it will run too:
pip install transformers==4.36.2
So when the run_generation script is used with do_sample, I get an error because top_k_top_p_filtering is missing. So I just added the function back (copied from https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py).
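For reference, the filtering that function performs can be sketched in plain Python as follows. This is a simplified single-sequence re-implementation of the logic, not the batched torch version from the linked file:

```python
import math

NEG_INF = float("-inf")

def top_k_top_p_filter(logits, top_k=0, top_p=1.0):
    """Set logits outside the top-k / nucleus (top-p) set to -inf.

    top_k > 0 keeps only the k highest-scoring tokens; top_p < 1.0 keeps
    the smallest set of tokens whose cumulative probability exceeds top_p
    (the token that crosses the threshold is kept, matching the shifted
    mask in the transformers implementation).
    """
    logits = list(logits)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep &= set(order[:top_k])
    if top_p < 1.0:
        # softmax over all logits, then walk down in probability order
        m = max(logits)
        exps = [math.exp(l - m) for l in logits]
        z = sum(exps)
        cum, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)  # always keep at least one token
            cum += exps[i] / z
            if cum > top_p:
                break
        keep &= nucleus
    return [l if i in keep else NEG_INF for i, l in enumerate(logits)]
```

Tokens set to -inf get zero probability after the softmax, so sampling can only pick from the surviving set.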