[Open] CherishineNi opened this issue 5 years ago
import argparse
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2
import pickle
from utils import CocoEvalLoader, to_var, show_images
from adaptive import Encoder2Decoder
from build_vocab import Vocabulary
from torch.autograd import Variable
from torchvision import transforms


def main():
    """Caption one test image with a pretrained Encoder2Decoder model.

    Loads the vocabulary pickle and the pretrained ``Encoder2Decoder`` weights.
    Preprocesses ``./data/test/1.jpg`` (resize to 224x224, ToTensor, ImageNet
    mean/std normalization) and runs ``model.sampler`` on it. Prints the
    decoded caption: the words up to, but not including, ``<end>``.

    The previous version failed with
    ``TypeError: torch.index_select ... got (torch.FloatTensor, int,
    !torch.cuda.LongTensor!)``. The model weights were left on the CPU while
    a CUDA index tensor was used against them. That tensor is presumably
    created inside ``model.sampler`` when CUDA is available; confirm in
    ``adaptive.py``. The model and the input image are now placed on the
    same device.
    """
    pretrained = 'models/adaptive-1.pkl'
    vocab_path = './data/vocab.pkl'
    use_cuda = torch.cuda.is_available()

    # Unpickling needs the Vocabulary class importable (imported above).
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Define model and load pretrained weights.
    model = Encoder2Decoder(256, len(vocab), 512)
    # Map every checkpoint tensor to the CPU first, so loading also works on
    # machines without the GPU the checkpoint was saved from.
    state_dict = torch.load(pretrained,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    if use_cuda:
        # Keep the weights on the same device as any CUDA tensors used during
        # sampling; mixing devices is what caused the index_select TypeError.
        model.cuda()
    model.eval()

    # Image transformation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # convert('RGB') so grayscale/RGBA files still give the 3 channels that
    # Normalize expects.
    image = Image.open('./data/test/1.jpg').convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    image = transform(image).unsqueeze(0)  # add a batch dimension of 1
    if use_cuda:
        image = image.cuda()
    # NOTE: `volatile` is the pre-0.4 PyTorch way to disable autograd; on
    # newer PyTorch wrap the sampling in `torch.no_grad()` instead.
    image_tensor = Variable(image, volatile=True)

    generated_captions, _, _ = model.sampler(image_tensor)
    # .cpu() is required before .numpy() when the output is on the GPU
    # (and is a no-op on the CPU).
    sampled_ids = generated_captions.data.cpu().numpy()[0]

    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        if word == '<end>':
            break
        sampled_caption.append(word)
    sentence = ' '.join(sampled_caption)
    print(sentence)
# Run the demo only when executed as a script. The original guard read
# `if name == 'main'`, which raises NameError because `name` is undefined.
if __name__ == '__main__':
    main()
# Error originally reported when running this script:
# TypeError: torch.index_select received an invalid combination of arguments
#   - got (torch.FloatTensor, int, !torch.cuda.LongTensor!)
Please help! Thanks.
Can you share your test files? Thank you very much!
# NOTE(review): this second copy of the script is truncated. Only the imports
# and the vocabulary loading at the start of main() survive here; the rest of
# main() (model definition, image preprocessing, sampling, decoding) is missing.
import argparse from PIL import Image import torch import matplotlib.pyplot as plt import numpy as np import cv2 import pickle from utils import CocoEvalLoader, to_var, show_images from adaptive import Encoder2Decoder from build_vocab import Vocabulary from torch.autograd import Variable from torchvision import transforms def main(): pretrained = 'models/adaptive-1.pkl' vocab_path = './data/vocab.pkl' with open(vocab_path, 'rb') as f: vocab = pickle.load(f)
# Define model and load pretrained weights
# Run the demo only when executed as a script. The original guard read
# `if name == 'main'`, which raises NameError because `name` is undefined.
if __name__ == '__main__':
    main()
# Error originally reported when running this script:
# TypeError: torch.index_select received an invalid combination of arguments
#   - got (torch.FloatTensor, int, !torch.cuda.LongTensor!)
Please help! Thanks.