neverix closed this 1 year ago
Thank you for implementing parallel forward. This will be useful for finetuning. Have you tried generating with a grid size larger than 1 though? It fails for me.
Thanks for noticing the batch bug. It should be ready to merge now.
Great, this version works now, thanks. I'll merge it. I want to make some minor changes to it, though. Can you send me the script you're using to call this so that I can check I didn't break it?
Sure, change the lines in min_dalle.py from L243 to:
attention_state=None,
prev_tokens=image_tokens[:i+1],
token_index=token_indices[:i+1],
If I start with your pull request and make that modification at L243 of min_dalle.py, it fails when I run it with this command:
python image_from_text.py --text='artificial intelligence' --seed=7 --no-mega --grid-size=2
Traceback (most recent call last):
  File "/Users/kuprel/min-dalle-parallel-forward/image_from_text.py", line 70, in <module>
    generate_image(
  File "/Users/kuprel/min-dalle-parallel-forward/image_from_text.py", line 56, in generate_image
    image = model.generate_image(
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 278, in generate_image
    return next(image_stream)
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 258, in generate_image_stream
    for image in image_stream:
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/min_dalle.py", line 239, in generate_raw_image_stream
    image_tokens[i + 1], attention_state = self.decoder.forward(
  File "/Users/kuprel/min-dalle-parallel-forward/min_dalle/models/dalle_bart_decoder.py", line 148, in forward
    prev_tokens = prev_tokens[list(range(image_count)) * 2]
IndexError: index 1 is out of bounds for dimension 0 with size 1
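The failing line duplicates the batch by indexing dimension 0 with a repeated index list, so it needs the leading dimension of prev_tokens to be at least image_count. Below is a minimal sketch of that failure mode. It uses NumPy as a stand-in for PyTorch tensors, and the helper name duplicate_batch, the shapes, and image_count = 4 for --grid-size=2 are assumptions for illustration, not code from the repository.

```python
import numpy as np

def duplicate_batch(prev_tokens: np.ndarray, image_count: int) -> np.ndarray:
    """Mimic `prev_tokens[list(range(image_count)) * 2]` from the traceback."""
    indices = list(range(image_count)) * 2  # e.g. [0, 1, 2, 3, 0, 1, 2, 3]
    return prev_tokens[indices]

image_count = 4  # assumed: a 2x2 grid of images

# Leading dimension equals image_count: the duplication works.
ok = duplicate_batch(np.zeros((image_count, 1), dtype=np.int64), image_count)
print(ok.shape)

# Leading dimension is 1: index 1 is already out of bounds.
try:
    duplicate_batch(np.zeros((1, image_count), dtype=np.int64), image_count)
    failed = False
except IndexError as error:
    failed = True
    print(error)
```

If this reading is right, the error points at a prev_tokens argument whose batch dimension is not first, rather than at the duplication itself.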
Huh, that's weird, because it runs on my end (code to replicate).
Oh, I must have mixed something up. What I observe now is that your pull request works with:
attention_state=None,
prev_tokens=image_tokens[:i+1],
token_index=token_indices[:i+1],
but fails with:
attention_state=attention_state,
prev_tokens=image_tokens[i],
token_index=token_indices[[i]]
I tried cloning your repository and saw the same thing. Do you observe this too?
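The two call patterns above (the full prefix with no attention state, versus one token with a carried attention state) are expected to give the same output at each position. Here is a toy sketch of why, using a single causal self-attention head in NumPy. The functions parallel_forward and cached_forward and the weight names are hypothetical and much simpler than the repository's decoder.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def parallel_forward(x, wq, wk, wv):
    """Causal self-attention over the whole prefix at once (no cache)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    length = x.shape[0]
    # Position t may only attend to positions <= t.
    mask = np.tril(np.ones((length, length), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    return softmax(scores) @ v

def cached_forward(x_new, cache, wq, wk, wv):
    """One-token step that appends to a key/value cache."""
    q = x_new @ wq
    keys = np.vstack([cache[0], x_new @ wk])
    values = np.vstack([cache[1], x_new @ wv])
    scores = q @ keys.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ values, (keys, values)

rng = np.random.default_rng(0)
dim, length = 8, 5
x = rng.normal(size=(length, dim))
wq, wk, wv = (rng.normal(size=(dim, dim)) for _ in range(3))

parallel_out = parallel_forward(x, wq, wk, wv)

# Feed the same tokens one at a time, carrying the cache forward.
cache = (np.zeros((0, dim)), np.zeros((0, dim)))
steps = []
for i in range(length):
    out, cache = cached_forward(x[[i]], cache, wq, wk, wv)
    steps.append(out)
cached_out = np.vstack(steps)

print(np.allclose(parallel_out, cached_out))
```

A comparison like this is one way to check that a cached path and a parallel path have not drifted apart after a refactor.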
Thanks, I fixed it again. I think something broke while I was adding batched generation.
OK, I confirmed it works now, thanks.
I went through and streamlined the code now that we've added parallel forward. For parallel forward, the arguments are now:
prev_tokens=image_tokens[:, :i+1],
token_index=token_indices[:i+1]
and for normal forward they are:
prev_tokens=image_tokens[:, [i]],
token_index=token_indices[[i]]
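As a quick illustration of what these index expressions select, the sketch below prints the resulting shapes. NumPy stands in for PyTorch tensors here, and the array sizes are assumptions chosen for illustration, not values read from the repository.

```python
import numpy as np

image_count, token_count = 4, 256  # assumed sizes
image_tokens = np.zeros((image_count, token_count + 1), dtype=np.int64)
token_indices = np.arange(token_count)
i = 10

# Parallel forward: every token up to and including position i.
parallel_tokens = image_tokens[:, :i + 1]
parallel_index = token_indices[:i + 1]

# Normal forward: only position i, with the sequence dimension kept.
step_tokens = image_tokens[:, [i]]
step_index = token_indices[[i]]

# Indexing with a bare integer would drop the sequence dimension instead.
dropped = image_tokens[:, i]

print(parallel_tokens.shape, parallel_index.shape)
print(step_tokens.shape, step_index.shape)
print(dropped.shape)
```

Wrapping i in a list keeps both call patterns at the same rank, so the decoder can treat the single-token case as a prefix of length one.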
Also, DalleBartDecoder.forward now returns logits, and DalleBartDecoder.sample_tokens samples the logits and returns tokens.
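Splitting the forward pass from sampling lets a caller use the logits directly, for example to compute a loss when finetuning. The thread does not say how sampling is done, so the sketch below assumes temperature scaling and top-k filtering. The function sample_from_logits is hypothetical and is not the signature of DalleBartDecoder.sample_tokens.

```python
import numpy as np

def sample_from_logits(logits, temperature, top_k, rng):
    """Sample one token id per row of `logits` with shape (batch, vocab)."""
    scaled = logits / temperature
    # Keep the top_k largest logits in each row and mask out the rest.
    kth_largest = np.sort(scaled, axis=-1)[:, -top_k][:, None]
    scaled = np.where(scaled >= kth_largest, scaled, -np.inf)
    # Softmax over the surviving logits.
    probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(probs.shape[-1], p=row) for row in probs])

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 16))  # assumed (batch, vocab) shape

tokens = sample_from_logits(logits, temperature=1.0, top_k=5, rng=rng)
greedy = sample_from_logits(logits, temperature=1.0, top_k=1, rng=rng)
print(tokens, greedy)
```

With top_k=1 the sampler reduces to taking the argmax of each row, which is a convenient sanity check.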
Let me know if I broke anything.
Implements parallel forward. Also allows passing inputs shorter than 256 tokens. Fully backwards compatible.