lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in PyTorch
MIT License

Image mask should be an argument to generate.py #300

Open afiaka87 opened 3 years ago

afiaka87 commented 3 years ago

I think "image mask engineering" is going to be about as equally important as people are finding prompt engineering. It's widely used in the Open AI blog post. Anyway, per usual my scatterbrain is constantly editing code and forgetting to upstream if it's worthwhile - here's a partial diff so I don't forget to merge this or if anyone else wants to I don't mind.

+# imports for the mask (add near the top of generate.py if not already present)
+from PIL import Image
+import torchvision.transforms as T
+
+parser.add_argument('--image_mask_path', type = str, help='Path to image to pass in as a mask during generation.')
+
 parser.add_argument('--hug', dest='hug', action = 'store_true')

 parser.add_argument('--chinese', dest='chinese', action = 'store_true')
@@ -51,6 +54,17 @@ parser.add_argument('--taming', dest='taming', action='store_true')

 args = parser.parse_args()

+# only prime with an image if a mask path was actually given,
+# so generate.py keeps working without the new flag
+image_mask_batch = None
+if args.image_mask_path is not None:
+    image_mask_path = Path(args.image_mask_path)
+    assert image_mask_path.exists(), f'{args.image_mask_path} could not be found.'
+
+    pil_image = Image.open(image_mask_path)
+    image_transform = T.Compose([
+        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+        T.RandomResizedCrop(256, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
+        T.ToTensor(),
+    ])
+    image_mask = image_transform(pil_image).cuda()
+    image_mask_batch = image_mask.repeat(args.batch_size, 1, 1, 1)
 # helper fns

 def exists(val):
@@ -101,7 +115,7 @@ for text in tqdm(texts):
     outputs = []

     for text_chunk in tqdm(text.split(args.batch_size), desc = f'generating images for - {text}'):
-        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
+        output = dalle.generate_images(text_chunk, filter_thres = args.top_k, img = image_mask_batch)
         outputs.append(output)

     outputs = torch.cat(outputs)
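
For context on what the img argument does downstream: as I understand the dalle_pytorch code, generate_images encodes the supplied image into discrete VAE codebook indices and uses an initial chunk of those tokens to prime the autoregressive sampling, with the transformer generating the rest. A rough sketch of that flow (not a verbatim copy of the library internals; treat the names and the 0.4375 ratio as my reading of the code, not gospel):

import torch

def prime_image_tokens(dalle, vae, image_batch, num_init_img_tokens = None):
    # sketch of the priming step, based on my reading of dalle_pytorch
    image_seq_len = dalle.image_seq_len

    # encode pixels -> discrete codebook indices, shape (batch, image_seq_len)
    indices = vae.get_codebook_indices(image_batch)

    # prime with roughly the top 43.75% of image tokens by default
    # (about 14 of 32 rows, as in the examples in OpenAI's blog post)
    num_prime = num_init_img_tokens if num_init_img_tokens is not None else int(0.4375 * image_seq_len)

    # the transformer then samples the remaining image tokens,
    # conditioned on the text tokens plus this prime
    return indices[:, :num_prime]

With the diff applied, a run would look something like python generate.py --dalle_path ./dalle.pt --text "a mannequin wearing a red dress" --image_mask_path ./prime.png (the new flag per the diff above; the rest per generate.py's existing arguments).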
afiaka87 commented 3 years ago

@lucidrains How easy/possible would it be to use a custom mask "structure"? Perhaps the target format could be whatever typical COCO-style segmentation data looks like, and then something could be abstracted on top of that which generates a segment from e.g. a white background. Or even better, the inverse of the white background, so that e.g. the shape of the "mannequin" or whatnot is all that is left unmasked. A sketch of that last idea follows below.
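
To make the white-background idea concrete, here's a minimal self-contained sketch (nothing here is in the repo; the helper name and the 240 threshold are my own assumptions). It marks near-white pixels as background and inverts that, so only the subject, e.g. the mannequin, stays unmasked:

import numpy as np
from PIL import Image

def subject_mask_from_white_background(path, threshold = 240):
    # hypothetical helper, not part of DALLE-pytorch
    img = np.asarray(Image.open(path).convert("RGB"))

    # background = pixels where all three channels are near-white
    background = (img > threshold).all(axis = -1)

    # the inverse: True only on the subject (e.g. the mannequin)
    subject = ~background
    return Image.fromarray((subject * 255).astype(np.uint8))

# usage:
# subject_mask_from_white_background('product_photo.png').save('subject_mask.png')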