kohya-ss / sd-scripts

Apache License 2.0
5.31k stars 880 forks source link

Retain alpha in `pil_resize` for `--alpha_mask` #1619

Closed emcmanus closed 2 months ago

emcmanus commented 2 months ago

Currently pil_resize() drops the alpha channel when --alpha_mask is supplied, but only if the image width does not exceed the bucket size.

This codepath is entered on the last line, here:

def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    image_height, image_width = image.shape[0:2]
    original_size = (image_width, image_height)  # size before resize

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # リサイズする
        if image_width > resized_size[0] and image_height > resized_size[1]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
        else:
            image = pil_resize(image, resized_size)
kohya-ss commented 2 months ago

Thank you for this!

Maru-mee commented 2 months ago

私の認識が間違っていなければ、 この変更は、sd3のみで、dev版には反映されていないようです。 しかし、dev版でも同じ事象(※1)が発生する問題のようなので、もし可能ならマージをお願いしたいです。 PR#1632と関係する要素であり、先に解決しておきたい課題です。

※1 下記のような事象です。 pilによるアルファチャンネル喪失、3チャンネル化 → alpha_mask作成時に if image.shape[2] == 4:にならず、 else: alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] に分岐し強制停止。

kohya-ss commented 2 months ago

devブランチにも同様の変更を行いました。

Maru-mee commented 2 months ago

ありがとうございます!