openai / gpt-2

Code for the paper "Language Models are Unsupervised Multitask Learners"
https://openai.com/blog/better-language-models/

interactive_conditional_samples not checking if prompt length is greater than hparams.n_ctx / 2 #121

Open albertwujj opened 5 years ago

albertwujj commented 5 years ago

If the prompt length is greater, this will break the model because the word position embedding (the wpe tensor) is not large enough. Should a check be added?
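For illustration, here is a minimal sketch of the kind of check that could be added (a sketch only; the exact placement inside interact_model and the error message are assumptions, while hparams.n_ctx, length, enc, and raw_text are names the script already uses):

    # Hypothetical guard before sampling: the position embedding (wpe) only has
    # hparams.n_ctx rows, and `length` of those positions are reserved for the
    # tokens the model will generate.
    context_tokens = enc.encode(raw_text)
    if len(context_tokens) > hparams.n_ctx - length:
        raise ValueError(
            'Prompt is %d tokens, but only %d fit alongside length=%d (n_ctx=%d)'
            % (len(context_tokens), hparams.n_ctx - length, length, hparams.n_ctx))

With the default length of hparams.n_ctx // 2, this corresponds to the "greater than hparams.n_ctx / 2" limit in the title.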

WuTheFWasThat commented 5 years ago

yeah feel free to send a PR!

albertwujj commented 5 years ago

I've been busy learning RL, but I'll get to it for sure when I catch a break! Did you see my pull request https://github.com/openai/gpt-2/pull/119? Is it just not worth the risk of a bug? I'm not totally sure, but from my reading/mental diagram I think it has literally ZERO effect on the output of the model.

albertwujj commented 5 years ago

I'll actually test it when I have time. It runs perfectly fine, I just haven't tested it on the same input and seeds.

albertwujj commented 5 years ago

Or you could test it better; I think you have more resources.

WuTheFWasThat commented 5 years ago

reviewed! thanks for the work on it!

would be great if you could test it (we're pretty busy around here...)

albertwujj commented 5 years ago

sure, will get to it! Though I imagine you don't actually have the infra set up to test this better than I can either.

MrKrzYch00 commented 5 years ago

It may also crash. I'm getting an error when the input is ~460 words (GPT-2). I read it may be related to the vocabulary, but I'm not sure if that's the case here.

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0] = 1024 is not in [0, 1024) [[{{node sample_sequence_1/while/model/GatherV2_1}}]]
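If this is the wpe gather, here is a minimal reproduction of the same class of error outside the model (TF 2 eager syntax, purely for illustration; the model's wpe table has hparams.n_ctx rows, so any position index of 1024 or more is out of range):

    import tensorflow as tf

    n_ctx = 1024
    wpe = tf.random.normal([n_ctx, 768])   # stand-in for the position-embedding table
    positions = tf.range(n_ctx + 1)        # one position index past the end of the table
    tf.gather(wpe, positions)              # on CPU: InvalidArgumentError, 1024 is not in [0, 1024)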

MrKrzYch00 commented 5 years ago

Wouldn't the fix actually be to count the encoded tokens and check that the count does not exceed the length variable? The check should be based on encoded tokens, not on the input text itself. I found that when I printed the value, exactly 512 tokens worked OK, while >512 crashed. However, my code is modified a bit more, so I can only give some hints:

context_length = 0
while not raw_text or context_length <= 0 or context_length > length:
    # [...] inside the loop: re-read raw_text, then re-count its tokens
    context_tokens = enc.encode(raw_text)
    context_length = len(context_tokens)

This fixed the problem for me; it's no longer crashing!

I modified the original script; pull request: #142
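For reference, a sketch of how that check could sit in the re-prompt loop of interactive_conditional_samples.py (a sketch, not the actual patch in #142; enc and length come from the surrounding interact_model, and the printed messages here are made up):

    while True:
        raw_text = input("Model prompt >>> ")
        context_tokens = enc.encode(raw_text)
        context_length = len(context_tokens)
        if not raw_text:
            print('Prompt should not be empty!')
            continue
        if context_length > length:
            print('Prompt is %d tokens; it must be at most %d to fit in the context window.'
                  % (context_length, length))
            continue
        break
    # ... sampling then proceeds with context_tokens as before

The key point is that the limit is counted in BPE tokens (via enc.encode), not in words or characters of the raw prompt.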