kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Add parallel forward (#74) #80

Closed neverix closed 1 year ago

neverix commented 1 year ago

Implements parallel forward. Also allows passing inputs shorter than 256 tokens. Fully backwards compatible.
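As the later comments show, a parallel forward feeds the decoder every image token up to the current step in one call with no cached attention state, while the normal forward feeds a single token and reuses the cached state. The NumPy sketch below illustrates why both paths should give the same outputs, using a toy single-head causal attention layer. The weights, sizes and function names are hypothetical; this is not min-dalle's decoder.

```python
import numpy as np

# Toy single-head causal self-attention: a hypothetical stand-in for one
# decoder attention layer, not min-dalle's actual implementation.
rng = np.random.default_rng(0)
d = 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def parallel_forward(x):
    """Process a whole prefix x of shape (n, d) in one pass with a causal mask."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(d)
    n = x.shape[0]
    # True above the diagonal marks future positions, which must not be attended to.
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores[mask] = -np.inf
    return softmax(scores) @ v

def incremental_forward(x_t, cache):
    """Process one token x_t of shape (d,), appending its key/value to the cache."""
    cache["k"].append(x_t @ W_k)
    cache["v"].append(x_t @ W_v)
    k = np.stack(cache["k"])
    v = np.stack(cache["v"])
    scores = (x_t @ W_q) @ k.T / np.sqrt(d)
    return softmax(scores) @ v

n = 5  # any prefix length, e.g. shorter than a full 256-token image
x = rng.standard_normal((n, d))
out_parallel = parallel_forward(x)

cache = {"k": [], "v": []}
out_incremental = np.stack([incremental_forward(x[t], cache) for t in range(n)])

print(np.allclose(out_parallel, out_incremental))  # True
```

Because the causal mask hides future positions, the parallel pass works on any prefix length and scores every position at once, which is presumably what makes it useful for finetuning.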

kuprel commented 1 year ago

Thank you for implementing parallel forward. This will be useful for finetuning. Have you tried generating with a grid size larger than 1 though? It fails for me.

neverix commented 1 year ago

Thanks for noticing the batch bug; it should be ready to merge now.

kuprel commented 1 year ago

Great, this version works now, thanks. I'll merge it, but I want to make some minor changes first. Can you send me the script you're using to call this, so that I can check I didn't break it?

neverix commented 1 year ago

Sure, change the lines in min_dalle.py from L243 to:

                    attention_state=None,
                    prev_tokens=image_tokens[:i+1],
                    token_index=token_indices[:i+1],

kuprel commented 1 year ago

If I start with your pull request and make that modification at L243 of min_dalle.py, it fails when I run it with this command:

python image_from_text.py --text='artificial intelligence' --seed=7 --no-mega --grid-size=2
Traceback (most recent call last):
  File "/Users/kuprel/min-dalle-parallel-forward/image_from_text.py", line 70, in <module>
    generate_image(
  File "/Users/kuprel/min-dalle-parallel-forward/image_from_text.py", line 56, in generate_image
    image = model.generate_image(
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 278, in generate_image
    return next(image_stream)
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 258, in generate_image_stream
    for image in image_stream:
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 239, in generate_raw_image_stream
    image_tokens[i + 1], attention_state = self.decoder.forward(
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/models/dalle_bart_decoder.py", line 148, in forward
    prev_tokens = prev_tokens[list(range(image_count)) * 2]
IndexError: index 1 is out of bounds for dimension 0 with size 1
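
The traceback points at a line that repeats the image batch twice with list(range(image_count)) * 2 (the decoder appears to process each image twice, for example once per guidance branch; that reading is an assumption). Indexing a tensor whose first dimension has size 1 with that list fails as soon as it asks for index 1. A minimal NumPy reproduction with hypothetical sizes follows; NumPy's advanced indexing behaves like PyTorch's here, though the error text differs slightly.

```python
import numpy as np

# Hypothetical sizes: grid size 2 is assumed to mean a 2 x 2 grid, i.e. 4 images.
image_count = 4
seq_len = 3  # image tokens decoded so far (illustrative)

# Index list built by the failing line: [0, 1, 2, 3, 0, 1, 2, 3].
batch_index = list(range(image_count)) * 2

# Works: prev_tokens has one row per image, so every index is in range.
prev_tokens_ok = np.arange(image_count * seq_len).reshape(image_count, seq_len)
ok_shape = prev_tokens_ok[batch_index].shape
print(ok_shape)  # (8, 3)

# Fails: prev_tokens has a single row, matching "size 1" in the traceback.
prev_tokens_bad = np.zeros((1, seq_len), dtype=np.int64)
failed = False
try:
    prev_tokens_bad[batch_index]
except IndexError as err:
    failed = True
    print("IndexError:", err)
```

In other words, the batch dimension of prev_tokens has to match image_count before the duplication, which is why the bug only shows up with a grid size larger than 1.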

neverix commented 1 year ago

Huh, that's weird, because it runs on my end (code to replicate).

[generated image]

kuprel commented 1 year ago

Oh, I must have mixed it up somehow. What I observe now is that your pull request works with:

attention_state=None,
prev_tokens=image_tokens[:i+1],
token_index=token_indices[:i+1], 

but fails with

attention_state=attention_state,
prev_tokens=image_tokens[i],
token_index=token_indices[[i]]

I tried cloning your repository and saw the same thing. Do you observe this too?

neverix commented 1 year ago

Thanks, I fixed it again. I think something broke while adding batched generation.

kuprel commented 1 year ago

OK, I confirmed it works now, thanks.

kuprel commented 1 year ago

I went through and streamlined the code now that we have parallel forward. The arguments for parallel forward are now:

prev_tokens=image_tokens[:, :i+1],
token_index=token_indices[:i+1]

and normal forward is

prev_tokens=image_tokens[:, [i]],
token_index=token_indices[[i]]
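The two calling conventions above differ only in how much of the token sequence is sliced out. Assuming image_tokens is laid out as (batch, sequence) and token_indices is a 1-D range over positions, as the slices suggest, this NumPy sketch (hypothetical sizes; PyTorch slicing behaves the same way for these cases) shows the resulting shapes, including why the single index is written as [i] inside brackets.

```python
import numpy as np

# Hypothetical sizes: 2 images, 256 image tokens each.
batch, seq_len = 2, 256
image_tokens = np.zeros((batch, seq_len), dtype=np.int64)
token_indices = np.arange(seq_len)

i = 10  # current decoding step

# Parallel forward: every token up to and including step i.
parallel_prev = image_tokens[:, :i + 1]
parallel_index = token_indices[:i + 1]

# Normal (cached) forward: only token i; indexing with [i] keeps the sequence axis.
step_prev = image_tokens[:, [i]]
step_index = token_indices[[i]]

print(parallel_prev.shape, parallel_index.shape)  # (2, 11) (11,)
print(step_prev.shape, step_index.shape)          # (2, 1) (1,)
print(image_tokens[:, i].shape)                   # (2,): a plain i drops the axis
```

Keeping a length-1 sequence axis in the normal case means both calls hand the decoder tensors of the same rank, so one forward method can serve both.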

Also, DalleBartDecoder.forward now returns logits, and DalleBartDecoder.sample_tokens samples from those logits and returns tokens.
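Splitting the two steps means forward can be used on its own wherever raw logits are wanted (presumably for a finetuning loss), while generation goes through sample_tokens. The sampler below is a hypothetical stand-in, not DalleBartDecoder.sample_tokens: it assumes last-position logits of shape (batch, vocab) and applies temperature plus top-k sampling in NumPy.

```python
import numpy as np

def sample_tokens(logits, temperature=1.0, top_k=50, rng=None):
    """Hypothetical sampler: turn logits of shape (batch, vocab) into one token per row.

    Divides by the temperature, keeps only the top_k highest logits per row,
    and draws one token from the resulting softmax distribution.
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64) / temperature
    # Threshold = k-th largest logit in each row; everything below it is masked out.
    kth = np.sort(logits, axis=-1)[:, -top_k][:, None]
    logits = np.where(logits >= kth, logits, -np.inf)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])

# Usage: a forward pass would produce the logits; sample_tokens picks the next tokens.
rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 2 ** 14))  # 2 images, 2**14-token vocabulary (illustrative)
tokens = sample_tokens(logits, temperature=1.0, top_k=1, rng=rng)
print(tokens.shape)  # (2,)
```

With top_k=1 the sampler reduces to greedy argmax decoding, which makes it easy to check; larger values of top_k and temperature trade determinism for diversity.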

Let me know if I broke anything.