lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch

How to print/display tensors as images? #221

Open VincentLu91 opened 2 years ago

VincentLu91 commented 2 years ago

I am testing the "Usage" example code from the README, but I have not been able to generate any meaningful images. It also didn't work when I tried it in a Colab notebook. Notebook (make sure GPU is selected as runtime type): https://colab.research.google.com/drive/1EsyMfsgQ5fzSMAccHWTbgTJ8st4RY6DS?usp=sharing

The example code doesn't show how to print/display/plot an image after the tensor is created. The images variable is a 4D tensor, so I've been looking into how to show the tensor as an image. So far I tried:

import torchvision.transforms as transforms
import matplotlib.pyplot as plt

img = images[0]
# permute from (C, H, W) to (H, W, C), the layout matplotlib expects
img = img.permute(1, 2, 0).cpu().numpy()
plt.imshow(img)

But it only gives a grainy, noise-like image. I'm not sure why it isn't generating anything meaningful; perhaps I am under-training the model? Is it due to the code snippet below?

# mock images (get a lot of this) and text encodings from large T5

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2): # do I need to increase the number of iterations?
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()

I haven't worked with 4D tensors before, but I would love to explore this model and get an understanding of the library in the process. I'm just not sure if I'm on the right track.

lucalevi commented 2 years ago

Hey @VincentLu91 !

I tried the following and it worked for me:

from torchvision.utils import save_image

img1 = images[0]  # torch.Size([3, 256, 256])
save_image(img1, 'img1.png')

This way it doesn't display/plot the image directly, but saves it to your working directory.
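If you would rather see it inline (e.g. in the Colab notebook), matplotlib works too. This is just a sketch, assuming images is the sampled tensor on the GPU with values roughly in [0, 1]:

import matplotlib.pyplot as plt

img = images[0].clamp(0, 1)                # clip anything outside the displayable range
img = img.permute(1, 2, 0).cpu().numpy()   # (C, H, W) -> (H, W, C) for matplotlib
plt.imshow(img)
plt.axis('off')
plt.show()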

As for the part

text_embeds = torch.randn(4, 256, 768).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

this is where it gets tricky, as far as I understand. These are just random placeholder tensors, so the model has nothing meaningful to learn. To get real results you have to train your own model, passing in real images and their corresponding text encodings instead.
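Roughly something like the following, where real_images and real_text_embeds are hypothetical stand-ins for a batch of your own training images and the T5 encodings of their captions:

# sketch only: same call as the README snippet, but with real data instead of torch.randn
# real_images: (batch, 3, 256, 256) float tensor of training images, values in [0, 1]
# real_text_embeds: (batch, seq_len, 768) T5 encodings of the matching captions
real_images = real_images.cuda()
real_text_embeds = real_text_embeds.cuda()

for unet_number in (1, 2):
    loss = imagen(real_images, text_embeds = real_text_embeds, unet_number = unet_number)
    loss.backward()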

Maybe somebody with more experience will be able to provide further details!

vancuonghoang commented 1 year ago

use this:

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3., return_pil_images = True)

count = 0
for image in images:
  image.save(f'{count}.png')
  count += 1 

You can add the keyword argument return_pil_images = True to get PIL images back instead of tensors.
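If you want to view them inline in a notebook rather than saving them, the returned PIL images can be passed straight to matplotlib. A small sketch, assuming images is the list returned by imagen.sample above:

import matplotlib.pyplot as plt

# display each sampled PIL image in its own figure
for image in images:
    plt.figure()
    plt.imshow(image)
    plt.axis('off')
plt.show()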

WaiterHsu commented 1 year ago

I used the method above, but the output still looks like a grainy chart, so I think the problem is with training rather than with converting the tensor to an image. Can anybody help?
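My guess is that the README snippet only steps each unet once or twice on random tensors, so grainy noise is all it can ever produce. Something along these lines is closer to what a real run needs; this is only a rough sketch, where dataloader is a hypothetical DataLoader yielding (images, text_embeds) batches of real data:

import torch

# sketch only: an optimizer, real (image, caption) data and many training steps are needed;
# a couple of backward passes on random tensors cannot produce meaningful samples
optimizer = torch.optim.Adam(imagen.parameters(), lr = 1e-4)

for unet_number in (1, 2):
    for images, text_embeds in dataloader:  # hypothetical iterator over your dataset
        loss = imagen(images.cuda(), text_embeds = text_embeds.cuda(), unet_number = unet_number)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The repo also ships an ImagenTrainer wrapper that handles the optimizer bookkeeping; see the README for details.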