karpathy / nanoGPT

The simplest, fastest repository for training/finetuning medium-sized GPTs.
MIT License
37.6k stars · 6k forks

Stop words? #36

Open BoyuanJackChen opened 1 year ago

BoyuanJackChen commented 1 year ago

I'm trying to use nanoGPT to generate Python code, but I don't see a stop-words implementation in the codebase, so what I'm getting is this:

Write a hello world function in Python3. Generate only code and no human language. Stop with double new line.
def hello_world():
... print("Hello world")
Hello world, you are nice
Output:
Hello world, you are nice
More about Python 3
Python 3 is a pretty new version of Python.
It has a lot new features like multithreading

I did make a naive change to `GPT.generate` in model.py, as shown below, but it's not working.

if idx_next in self.stop:
    break

I wonder if there's any way to let it learn to stop with "\n\n" without fine tuning.
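One likely problem with the snippet above is that `idx_next` is a `(b, 1)` tensor, not an int, and a single-token membership test can never match a multi-token stop string anyway. A minimal sketch of a tail check that compares plain token ids could look like this (the function names and `stop_ids` are hypothetical, not part of nanoGPT; `idx` stands for the `(b, t)` token tensor inside `GPT.generate`):

```python
import torch

def ends_with(seq, stop_ids):
    """True if the token-id list `seq` ends with the `stop_ids` sequence."""
    n = len(stop_ids)
    return len(seq) >= n and seq[-n:] == stop_ids

def should_stop(idx, stop_ids):
    # Inside the generate loop, after appending idx_next to idx:
    # compare the tail of the (first) sequence in the batch as a Python list,
    # so multi-token stops like the encoding of "\n\n" can match.
    return ends_with(idx[0].tolist(), stop_ids)
```

Note that the ids for `"\n\n"` depend on the tokenizer, so `stop_ids` has to be computed with the same encoder used for training.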

VatsaDev commented 1 year ago

The best I could do was generate, then string-partition. This works fine for short inferences, but is terrible for long ones.
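The generate-then-partition approach mentioned here can be sketched in a couple of lines (the helper name is hypothetical): decode the full output, then keep everything before the first stop marker.

```python
def truncate_at_stop(text, stop="\n\n"):
    # Keep only the text before the first occurrence of the stop marker;
    # if the marker never appears, partition returns the whole string.
    head, _sep, _tail = text.partition(stop)
    return head
```

The downside, as noted, is that the model still generates the full sequence; the truncation only happens after decoding.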

luminousking commented 4 months ago

@BoyuanJackChen Same problem here, have you solved the stop word issue? thx!

BoyuanJackChen commented 4 months ago

> @BoyuanJackChen Same problem here, have you solved the stop word issue? thx!

Yes. If you are using the Hugging Face transformers library, you can do the following:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class StopSequences(LogitsProcessor):
    def __init__(self, stop_ids, batch_size, encounters=1, eos_token_id=2):
        super().__init__()
        self.stop_sequences = stop_ids
        self.batch_size = batch_size
        self.encounters = [encounters] * batch_size
        self.NUM_ENCOUNTERS = encounters
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        # A score vector that forces EOS: -inf everywhere except eos_token_id
        forced_eos = torch.full((scores.size(1),), -float("inf"),
                                device=scores.device, dtype=scores.dtype)
        forced_eos[self.eos_token_id] = 0
        for stop in self.stop_sequences:
            # Check if the input_ids end with the stop sequence
            for i in range(self.batch_size):
                if self.encounters[i] <= 0:
                    continue
                if input_ids[i][-len(stop):].tolist() == stop:
                    self.encounters[i] -= 1
                    if self.encounters[i] <= 0:
                        scores[i] = forced_eos
        return scores

# Initialize your model and tokenizer
model = ...
tokenizer = ...
prompt_ids = ...

# An example of what stop_words_ids should look like
# (the token ids are tokenizer-specific)
stop_words = ["\n#", "\n```\n"]
stop_words_ids = [[13,29937], [13,28956,13], [13,28956,30004]]
logits_processor = LogitsProcessorList([StopSequences(stop_words_ids, batch_size=1, encounters=1)])
answer_ids = model.generate(
    **prompt_ids,
    use_cache=True,
    do_sample=False,
    logits_processor=logits_processor,
)
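The forced-EOS trick can be sanity-checked without a model or tokenizer. This standalone snippet (dummy ids, a toy vocabulary of size 10, hypothetical values throughout) just re-applies the core check from `__call__`: once the tail of `input_ids` matches a stop sequence, every logit except `eos_token_id` is set to `-inf`, so greedy decoding must emit EOS next.

```python
import torch

eos_token_id = 2
vocab_size = 10
stop = [7, 8]  # a dummy two-token stop sequence

input_ids = torch.tensor([[5, 7, 8]])   # the sequence ends with the stop ids
scores = torch.zeros(1, vocab_size)     # dummy logits

if input_ids[0][-len(stop):].tolist() == stop:
    forced_eos = torch.full((vocab_size,), -float("inf"))
    forced_eos[eos_token_id] = 0.0
    scores[0] = forced_eos

# Greedy decoding now has to pick the EOS token
assert scores[0].argmax().item() == eos_token_id
```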