SamuelJoutard / Permutohedral_attention_module

Apache License 2.0

Try the Permutohedral lattice #1

Closed jgsimard closed 5 years ago

jgsimard commented 5 years ago

Hi, First, very nice work! Second, I tried to use your code for the permutohedral lattice as a simple filter, and I think there might be a problem, because when I run this code on a 2D RGB image:

import numpy as np
import torch
from PIL import Image

from permuthohedral_lattice import PermutohedralLattice

img = np.asarray(Image.open("small_input.bmp"))

indices = np.reshape(np.indices(img.shape[:2]), (2, -1))[None, :]
rgb = np.reshape(img, (3, -1))[None, :]

pl = PermutohedralLattice.apply

out = pl(torch.from_numpy(indices/5.0).cuda().float(),
         torch.from_numpy(rgb/0.125).cuda().float())

output = out.squeeze().cpu().numpy()
output = np.reshape(output, img.shape)
result = Image.fromarray((output/output.max() *255).astype(np.uint8))
result.save('out.bmp')

I get this image:

[image: out]

I see two problems with this image: the image is duplicated and there are horizontal black stripes. Do you know what might be causing this? Thanks

SamuelJoutard commented 5 years ago

Hi, First of all, thank you for your comment! I am running experiments on my side to solve your issue. The first thing that comes to mind is that when you load an image, the array layout is LxWx3 (length x width x channels). So before the line: rgb = np.reshape(img, (3, -1))[None, :] I think you should rearrange the order of your dimensions with: img = np.transpose(img, (2, 0, 1)). Otherwise the reshape only re-chunks the flat buffer and mixes the color channels of neighbouring pixels instead of separating them. You will need the matching rearrangement when reshaping your output. I have a test notebook where it works when I include the changes suggested above. I can share it with you if needed.
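To make the effect of the missing transpose visible without the lattice, here is a small NumPy-only sketch. The tiny synthetic array and the names wrong and right are made up for illustration; they are not part of the repository.

```python
import numpy as np

# Hypothetical 2x4 "image": channel c of every pixel holds the value c,
# so a correct channel-first layout has constant rows 0, 1 and 2.
h, w = 2, 4
img = np.zeros((h, w, 3), dtype=np.uint8)
img[..., 1] = 1
img[..., 2] = 2

# Reshaping H x W x 3 directly only re-chunks the flat buffer, so the
# channel values of neighbouring pixels end up interleaved in each row.
wrong = np.reshape(img, (3, -1))

# Moving the channel axis first, then flattening, gives one row per channel.
right = np.reshape(np.transpose(img, (2, 0, 1)), (3, -1))

print(wrong)
print(right)
```

Each row of right holds a single channel, while each row of wrong cycles through all three, which is consistent with the duplicated picture in the report.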

Hope this helps.

PS: try this version of your code:


import numpy as np
import torch
from PIL import Image

from permuthohedral_lattice import PermutohedralLattice

img = np.asarray(Image.open("small_input.bmp"))

indices = np.reshape(np.indices(img.shape[:2]), (2, -1))[None, :]
img = np.transpose(img, (2, 0, 1))
rgb = np.reshape(img, (3, -1))[None, :]

pl = PermutohedralLattice.apply

out = pl(torch.from_numpy(indices/5.0).cuda().float(),
         torch.from_numpy(rgb/0.125).cuda().float())

output = out.squeeze().cpu().numpy()
output = np.transpose(output, (1, 0))
output = np.reshape(output, (img.shape[1], img.shape[2], 3))
result = Image.fromarray((output/output.max() *255).astype(np.uint8))
result.save('out.bmp')

jgsimard commented 5 years ago

Thanks for the quick response! The first problem is solved!

SamuelJoutard commented 5 years ago

Do you still get black stripes with the code above? Can you upload the input image and the filtered image you obtain?

jgsimard commented 5 years ago

Of course!

The code

import numpy as np
import torch
from PIL import Image

from permuthohedral_lattice import PermutohedralLattice

img = np.asarray(Image.open("small_input.bmp"))

indices = np.reshape(np.indices(img.shape[:2]), (2, -1))[None, :]
img = np.transpose(img, (2, 0, 1))
rgb = np.reshape(img, (3, -1))[None, :]

pl = PermutohedralLattice.apply

out = pl(torch.from_numpy(indices).cuda().float(),
         torch.from_numpy(rgb/100).cuda().float())

output = out.squeeze().cpu().numpy()
output = np.transpose(output, (1, 0))
output = np.reshape(output, (img.shape[1], img.shape[2], 3))
result = Image.fromarray((output/output.max() *255).astype(np.uint8))
result.save('out.png')

Image before:

[image: small_input_2]

Image after:

[image: out]

SamuelJoutard commented 5 years ago

Well, somehow I ran very similar code and I don't get those black stripes:

from PL_sym import PermutohedralLattice
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt

im = cv2.imread("elephant.jpg")
indices = np.reshape(np.indices(im.shape[:2]), (2, -1))[None, :]
im = np.transpose(im, (2, 0, 1))
rgb = np.reshape(im, (3, -1))[None, :]

pl = PermutohedralLattice.apply

out = pl(torch.from_numpy(indices/5.0).cuda().float(),
         torch.from_numpy(rgb/0.125).cuda().float())

output = out.squeeze().cpu().numpy()
output = np.transpose(output, (1, 0))
output = np.reshape(output, (848, 1272, 3))

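# Show the filtered result and the input in separate figures; a second
# imshow on the same axes would otherwise replace the first one.
# cv2.imread returns BGR, so red and blue appear swapped in both plots.
plt.figure()
plt.imshow(output/output.max())
plt.figure()
plt.imshow(np.transpose(im, (1, 2, 0)))
plt.show()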

And here are the images shown: [image]

I will check whether I made any changes since I uploaded the code, but to me the issue must come from array manipulation. I will keep you updated.
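One way to test the array-manipulation hypothesis in isolation is to skip the filter entirely and check that flattening and un-flattening returns the original image. The sketch below assumes the (1, C, H*W) layout used in the snippets above; the helper names to_channel_first_rows and to_image are hypothetical and do not exist in the repository.

```python
import numpy as np

def to_channel_first_rows(img):
    """H x W x C image -> array of shape (1, C, H*W)."""
    c = img.shape[2]
    return np.reshape(np.transpose(img, (2, 0, 1)), (c, -1))[None, :]

def to_image(rows, height, width):
    """(1, C, H*W) array -> H x W x C image (inverse of the function above)."""
    flat = np.transpose(np.squeeze(rows, axis=0), (1, 0))  # (H*W, C)
    return np.reshape(flat, (height, width, flat.shape[1]))

# Random stand-in for a loaded image; non-square on purpose so that a
# swapped height/width would be detected.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(5, 7, 3), dtype=np.uint8)

rows = to_channel_first_rows(img)
restored = to_image(rows, img.shape[0], img.shape[1])
print(np.array_equal(restored, img))

# The pixel coordinates must line up with the flattened colors as well.
indices = np.reshape(np.indices(img.shape[:2]), (2, -1))[None, :]
n = 17  # arbitrary flat pixel position
i, j = indices[0, :, n]
aligned = np.array_equal(rows[0, :, n], img[i, j])
print(aligned)
```

If both checks pass on the real input but the stripes remain, the reshaping is probably not the culprit and the filter output itself is worth inspecting.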

jgsimard commented 5 years ago

Thanks!

SamuelJoutard commented 5 years ago

Hi,

After checking, the code has not changed significantly since the version that is online. I will close this issue because I believe those stripes are due to array manipulation, but please do not hesitate to give me an update if you find out anything.
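For anyone following up, a rough diagnostic like the one below could show whether the stripes are already present in the filter output before the conversion to uint8. It is only a sketch on synthetic data: the function name report_suspect_rows is hypothetical, and it does not explain where such rows would come from.

```python
import numpy as np

def report_suspect_rows(output):
    """Return indices of rows that are all zero or contain non-finite values."""
    per_row = output.reshape(output.shape[0], -1)
    zero_rows = np.flatnonzero(np.all(per_row == 0, axis=1))
    bad_rows = np.flatnonzero(~np.all(np.isfinite(per_row), axis=1))
    return zero_rows, bad_rows

# Synthetic stand-in for a filtered H x W x 3 float image with two defects.
output = np.ones((6, 4, 3), dtype=np.float32)
output[2] = 0.0            # an all-black row
output[4, 1, 0] = np.nan   # a NaN inside another row

zero_rows, bad_rows = report_suspect_rows(output)
print("all-zero rows:", zero_rows)
print("rows with NaN/inf:", bad_rows)
```

Running the same function on the reshaped output from the snippets above, before dividing by output.max(), would indicate which image rows to look at.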

Thank you!