chaiNNer-org / spandrel

Spandrel gives your project support for various PyTorch architectures meant for AI Super-Resolution, restoration, and inpainting. Based on the model support implemented in chaiNNer.
MIT License
139 stars 12 forks source link

the usage of MaskedImageModelDescriptor #146

Open whybfq opened 8 months ago

whybfq commented 8 months ago

does the following way of using the MaskedImageModelDescriptor is true or not, it works for LAMA while not work for MAT, any suggestions? Thanks a lot.

from contextlib import contextmanager
import torch.nn.functional as F

import gc
import cv2
import numpy as np
import torch
from PIL import Image
from spandrel import ModelLoader, ImageModelDescriptor, MaskedImageModelDescriptor
import os
import logging

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:100"
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['HYDRA_FULL_ERROR'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

torch.cuda.set_per_process_memory_fraction(0.85)  # 设置显存使用上限
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

logger = logging.getLogger(__name__)
device = torch.device("cuda:0")

@contextmanager
def collect_gc():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
    logger.info("Pre-execution garbage collection completed.")
    try:
        yield
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
        logger.info("Post-execution garbage collection completed.")

def ceil_modulo(x, mod):
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod

def pad_tensor_to_modulo(img, mod=8):
    batch_size, channels, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')

def main():
    model = ModelLoader(device).load_from_file(r"spandrel/tests/models/big-lama.pt")
    assert isinstance(model, MaskedImageModelDescriptor)
    # send it to the GPU and put it in inference mode
    model = model.to(device)
    model.eval()

    with collect_gc():
        img_lq = cv2.imread("inputs/22.jpg", cv2.IMREAD_COLOR).astype(np.float32) / 255.
        img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]],(2, 0, 1))  # HCW-BGR to CHW-RGB
        image_tensor = torch.from_numpy(img_lq).to(device).unsqueeze(0)  # CHW-RGB to NCHW-RGB

        mask_lq = cv2.imread("inputs/mask22.png", cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.
        mask_tensor = torch.from_numpy(mask_lq).float().to(device).unsqueeze(0).unsqueeze(0)  # (HxW to 1xHxW) to N1xHxW

        # Pillow version
        # image, mask = Image.open("inputs/27.jpg").convert('RGB'), Image.open("inputs/mask27.png").convert('L')
        # image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0).to(device).float().div(255.)
        # mask_tensor = torch.from_numpy(np.array(mask)).unsqueeze(0).unsqueeze(0).to(device)  # 添加额外的批量和通道维度

        # unpad_to_size = [image_tensor.shape[2], image_tensor.shape[3]]
        image_tensor, mask_tensor = pad_tensor_to_modulo(image_tensor), pad_tensor_to_modulo(mask_tensor)

        print(image_tensor.shape, "\n", mask_tensor.shape)  # Expected to be in NCHW format

        # inpaint model
        with torch.no_grad():
            output = model(image_tensor, mask_tensor)

        # save image for Pillow IMAGE
        inpainted_image = Image.fromarray((output.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        inpainted_image.save("inpaint.png")
        inpainted_image.show()

        # save image for opencv image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()

        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite("output.jpg", output)

if __name__ == '__main__':
    main()
joeyballentine commented 8 months ago

Is there a specific error you get when running MAT? Knowing what error your code is throwing should be able to help us track down what's wrong. On our side, both implementations appear to work.

whybfq commented 8 months ago

Is there a specific error you get when running MAT? Knowing what error your code is throwing should be able to help us track down what's wrong. On our side, both implementations appear to work.

yes, of course. Sorry for not co clear in the beginning, after I changed the model path and the other things the same, model = ModelLoader(device).load_from_file(r"spandrel/tests/models/Places_512_FullData_G.pth")

there was an error like the following:

torch.Size([1, 3, 800, 800]) 
 torch.Size([1, 1, 800, 800])
An error occurred: shape '[1, 12, 8, 12, 8, 180]' is invalid for input of size 1800000
Traceback (most recent call last):
  File "/home/hangyi/Downloads/gaoqing/spandrel/interface.py", line 122, in <module>
    main()
  File "/home/hangyi/Downloads/gaoqing/spandrel/interface.py", line 105, in main
    output = model(image_tensor, mask_tensor)
  File "/home/hangyi/.local/lib/python3.10/site-packages/spandrel/__helpers/model_descriptor.py", line 342, in __call__
    output = self._call_fn(self.model, image, mask)
  File "/home/hangyi/.local/lib/python3.10/site-packages/spandrel/__helpers/model_descriptor.py", line 323, in <lambda>
    self._call_fn = call_fn or (lambda model, image, mask: model(image, mask))
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/spandrel/architectures/MAT/arch/MAT.py", line 1607, in forward
    output = self.model(
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/spandrel/architectures/MAT/arch/MAT.py", line 1583, in forward
    img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/spandrel/architectures/MAT/arch/MAT.py", line 1500, in forward
    out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hangyi/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
joeyballentine commented 8 months ago

I think your mask tensor is the wrong shape but I'm not 100% sure. I can look into this more later

whybfq commented 8 months ago

I think your mask tensor is the wrong shape but I'm not 100% sure. I can look into this more later

Thanks for your time.

whybfq commented 8 months ago

I thik I know thre reason, for MAT, I need to pad or resize the image to make its size a multiple of 512.

joeyballentine commented 8 months ago

Thanks for letting us know. Sorry I didn't end up looking more into it myself

@RunDevelopment we should confirm if that's the case and if so, set its size requirements accordingly

RunDevelopment commented 8 months ago

set its size requirements accordingly

It already is. The repo docs even say that it supports multiples of 512 (and not just exactly 512), so I have no idea why it doesn't work.