xinntao / Real-ESRGAN

Real-ESRGAN aims at developing Practical Algorithms for General Image/Video Restoration.
BSD 3-Clause "New" or "Revised" License

Unraveling the image into dimensions is incorrect for 'mps' (M1 Mac) device #442

Open · Langhalsdino opened this issue 2 years ago

Langhalsdino commented 2 years ago

What

When running the current master commit e5763af5749430c9f7389f185cc53f90c4852ed5 on an M1 Mac with the environment from environment-mac.yaml, the resulting image is split into 3x3 tiles. This is due to incorrect unraveling of the image dimensions in the M1 Mac mps environment.

(attached image: frame000008out — the output frame, split into 3x3 tiles)

The relevant conversion in utils.py:

output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))

This only occurs for the mps device; if I choose cpu, everything works fine. So it is probably due to a different definition of the RGB channel layout within the M1 (mps) PyTorch backend.
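For illustration (this demo is mine, not from the report), here is a minimal NumPy sketch of the kind of layout mix-up described above, and why a flatten/reshape round trip can undo it:

import numpy as np

# A tiny "image" in HWC layout: 2x4 pixels, 3 channels.
hwc = np.arange(2 * 4 * 3).reshape(2, 4, 3)

# The correct HWC -> CHW conversion only rearranges the axes:
chw = np.transpose(hwc, (2, 0, 1))

# If the CHW bytes are later reinterpreted with HWC dimensions
# (which is what the tiled output suggests happened), the pixels
# get scrambled:
scrambled = chw.flatten().reshape(2, 4, 3)

# The quick hack below undoes this: reinterpret the bytes as CHW,
# then transpose properly back to HWC.
recovered = scrambled.flatten().reshape(3, 2, 4).transpose(1, 2, 0)
assert (recovered == hwc).all()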

How to reproduce / environment / ...

The error occurs with:

upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=tile,
    tile_pad=tile_pad,
    pre_pad=pre_pad,
    half=False,
    device='mps'
)
upsampler.enhance(img, outscale=outscale)

Everything works fine with:

upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=tile,
    tile_pad=tile_pad,
    pre_pad=pre_pad,
    half=False,
    device='cpu'
)
upsampler.enhance(img, outscale=outscale)

environment-mac.yaml

Quick hack to solve it

Since only the dimensions are incorrect, loading the saved image and reshaping it solves the issue. Sadly, I could not find which part of the mps backend scrambles the dimensions, so it is probably best to fix them at the end if the device is mps.

image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# reinterpret the flat buffer as CHW, then transpose back to HWC
image = image.flatten().reshape(
    (image.shape[2], image.shape[0], image.shape[1])
).transpose((1, 2, 0))
cv2.imwrite(img_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
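The same reinterpretation can also be done in memory on the array returned by enhance(), avoiding the write/read round trip. A sketch mirroring the file-based hack above (enhance() returns an (output, img_mode) tuple; out_path is a placeholder for your output file):

import cv2
import numpy as np

# in-memory variant of the file-based hack above
output, _ = upsampler.enhance(img, outscale=outscale)
output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
h, w, c = output.shape
output = output.flatten().reshape(c, h, w).transpose(1, 2, 0)
# cv2 wants a contiguous buffer after the transpose
output = cv2.cvtColor(np.ascontiguousarray(output), cv2.COLOR_RGB2BGR)
cv2.imwrite(out_path, output)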

My mind always gets a bit messed up with dimensions and reshaping arrays, so I am using the quick fix. But adding a few lines to utils.py at #L205 solves the issue permanently (the completed snippet is in a comment below):

output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if self.device == 'mps':
    output_img = output_img.flatten().reshape( ...
xinntao commented 2 years ago

@Langhalsdino Thanks for reporting the issue.

It seems that you have found the solution! Can you help to contribute by opening a pull request to do the fix? Thanks 😄

simasima121 commented 1 year ago

Thanks for this.

Final part looks like this:

if str(self.device) == 'mps':
    output_img = output_img.flatten().reshape(
        output_img.shape[2], output_img.shape[0], output_img.shape[1]
    ).transpose((1, 2, 0))
    output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
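Note the str(self.device) cast compared to the snippet in the original report: self.device is typically a torch.device object rather than a plain string, so a direct comparison with 'mps' would not match; converting it to str first makes the check reliable.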
hholtmann commented 1 year ago

Applying the fix to the current master 5ca1078535923d485892caee7d7804380bfc87fd and executing inference_realesrgan.py (without any parameters) with the device set to "mps" processes the example images really fast. Unfortunately, there are lines in the generated images, and for the provided example images that contain alpha channels the image is still split into 3x3 tiles. Any ideas?

(attached images: 0014_out and children-alpha_out — showing the line artifacts and the 3x3 tiling on the alpha-channel example)

BlueYellowGreen commented 1 year ago

According to this issue (PyTorch MPS Backend), the problem occurs when moving data from the CPU to mps. It can be solved by changing img.to('mps') to img.contiguous().to('mps').
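For reference, a minimal sketch of where that change lands. In RealESRGANer.pre_process (realesrgan/utils.py) the input tensor is moved to the device roughly like this; the exact surrounding code may differ between versions:

import numpy as np
import torch

def pre_process(self, img):
    # HWC numpy image -> CHW float tensor (a strided, non-contiguous view)
    img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
    # Force a contiguous layout before moving to the device. Without
    # .contiguous(), copying the strided tensor to 'mps' scrambles the
    # dimensions, which is what produced the 3x3 tiling above.
    self.img = img.unsqueeze(0).contiguous().to(self.device)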