lllyasviel / ControlNet-v1-1-nightly

Nightly release of ControlNet 1.1
4.48k stars 364 forks source link

Input format for inpainting model #82

Closed lxa9867 closed 1 year ago

lxa9867 commented 1 year ago

Hi everyone, I got confused about the inpainting model input format.

I think it is different from the stablediffusion inpainting. I suppose it takes a (H,W,3) image as a condition and set the masked area as -1 as the control. I tried this when I finetune the inpainting model. However, it cannot work. I attached my code and hope that someone can let me know if I did anything wrong. Thanks.

class COCOHuman(Dataset):
    def __init__(self):
        self.data = glob.glob('/mnt/data/coco/human/masks/*_mask.jpg')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        prompt = ""
        img = cv2.imread(item.replace('_mask', '').replace('masks', 'imgs'))
        gt_mask = cv2.imread(item)
        gt_mask[gt_mask > 1] = 1

        gt_mask = Image.fromarray(gt_mask)
        gt_mask = pad_image(gt_mask, (400, 400))
        bg = Image.new('RGB', (512, 512), (0, 0, 0))
        bg.paste(gt_mask, (56, 56))
        gt_mask = np.array(bg.convert('L'))

        # mask later
        source = copy.deepcopy(img)
        target = img

        # Do not forget that OpenCV read images in BGR order.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # random mask out area
        valid = np.where(gt_mask == 1)
        pos = np.random.choice(len(valid[0]))
        cx, cy = valid[0][pos], valid[1][pos]
        w, h = np.random.randint(100, 300), np.random.randint(100, 300)
        x, xx, y, yy = max(0, cx - w//2), min(512, cx + w//2), max(0, cy - h//2), min(512, cy + h//2)
        source[x:xx, y:yy] = -255.0

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source, gt_mask=gt_mask)
brappier commented 1 year ago

Hey, did you figure it out?

MuyuenHoshino commented 8 months ago

Hello, is there any progress?

geroldmeisinger commented 8 months ago

also see #89