UniT: Unified Tactile Representation for Robot Learning

Testing VQVAE #1

Closed: Jcastanyo closed this issue 2 weeks ago

Jcastanyo commented 2 weeks ago

Hi!

First, congratulations on your incredible work!

I want to test your model (VQ-VAE) to reconstruct tactile images with your dataset and then with mine.

Is there any code to run the test? I only see code to run the training. I tried to change the "train" param in "vqvae_representation_key.yaml" from true to false. However, I'm getting an error.

Thanks in advance.

ZhengtongXu commented 2 weeks ago

Thanks for your interest in our work. For reconstruction, here is a simple example of forwarding the vqgan model:

  # `image` is a batched image tensor; `vqgan` is an initialized VQModel.
  img_batch = {'image': image}
  rec_output, _ = vqgan(vqgan.get_input(img_batch, 'image'))
  rec_output = rec_output.to(dtype=torch.float32)
  # Drop the batch dimension and convert CHW -> HWC for saving/display.
  rec_output = rec_output.squeeze(0).detach().cpu().numpy().transpose((1, 2, 0))
  rec_output = (rec_output * 255).astype(np.uint8)
Jcastanyo commented 2 weeks ago

Thanks for your quick answer!

Where is vqgan in your repo? Could you please share a script that runs the test in a bit more detail?

Thanks in advance!

ZhengtongXu commented 2 weeks ago

You may want to check this file https://github.com/ZhengtongXu/UniT/blob/main/UniT/taming/models/vqgan.py
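
For reference, a minimal loading sketch might look like the following (the import path is assumed from the repo layout, and cfg.model refers to the model section of the training config):

  # Import path assumed from the file linked above.
  from UniT.taming.models.vqgan import VQModel

  vqgan = VQModel(**cfg.model)  # cfg: training config, assumed loaded elsewhere
  vqgan.init_from_ckpt('path/to/checkpoint.ckpt')
  vqgan.eval()  # inference mode for reconstruction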

Jcastanyo commented 2 weeks ago

Yes, that's the file I'm looking at. I'm able to run the training by calling self.model.fit(), but when I try to call self.model.test(), I get the following error:

"TypeError: 'RepresentationWrapper' object is not iterable"

ZhengtongXu commented 2 weeks ago

You can try the code I posted above instead of calling model.test().

Jcastanyo commented 2 weeks ago

I managed to solve it by following your snippet! Thank you very much! I'm attaching my code here in case it's useful for other people:

import cv2
import numpy as np
import torch
from UniT.taming.models.vqgan import VQModel

# Build the model from the training config and load the checkpoint.
model = VQModel(**cfg.model)  # cfg: the training config (e.g., loaded via hydra/omegaconf)
ckptdir = "/UniT_shared/checkpoint-epoch=340.ckpt"
model.init_from_ckpt(ckptdir)
model.eval()

# Load a tactile image and scale it to [0, 1] (note: cv2 loads channels as BGR).
image = cv2.imread("/UniT_shared/data/images/0.png")
image = image.astype(np.float32) / 255.

resized_image = cv2.resize(image, (160, 128))

# Add a batch dimension; get_input() handles the HWC -> CHW permutation.
torch_image = torch.from_numpy(resized_image)
torch_image = torch.unsqueeze(torch_image, 0)

img_batch = {"image": torch_image}
model_input = model.get_input(img_batch, "image")

# Forward pass; no gradients needed for reconstruction.
with torch.no_grad():
    rec_image, _ = model(model_input)
rec_image = rec_image.to(dtype=torch.float32)
rec_image = rec_image.squeeze(0).detach().cpu().numpy().transpose((1, 2, 0))
# Clip to [0, 1] before the uint8 cast to avoid wraparound.
rec_image = (np.clip(rec_image, 0.0, 1.0) * 255).astype(np.uint8)
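
For a quick visual check, the input and reconstruction can be stacked side by side and saved (the output filename below is just an example):

# Re-scale the input back to uint8 and stack it next to the reconstruction.
input_uint8 = (resized_image * 255).astype(np.uint8)
comparison = np.hstack([input_uint8, rec_image])
cv2.imwrite("comparison.png", comparison)  # example output path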