Open inkzk opened 6 years ago
Please refer to https://github.com/yufengm/Adaptive/blob/4c0555af546cdbd49e99ff1bd6e91d1654ae0cd2/train.py#L152 for evaluation on the validation dataset.
Can you share the test code? I tried to write the test script, but ran into some problems.
```python
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

from adaptive import AttentiveCNN, Decoder, Encoder2Decoder
from build_vocab import Vocabulary

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_image(image_path, transform=None):
    image = Image.open(image_path)
    image = image.resize([224, 224], Image.LANCZOS)
    if transform is not None:
        image = transform(image).unsqueeze(0)
    return image


def main(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    encoder = AttentiveCNN(args.embed_size, args.hidden_size).eval()  # eval mode (batchnorm uses moving mean/variance)
    encoder2decoder = Encoder2Decoder(args.embed_size, len(vocab), args.hidden_size)

    # Load the trained model parameters
    encoder2decoder.load_state_dict(torch.load(args.encoder2decoder_path))

    # Prepare an image
    image = load_image(args.image, transform)

    # Generate a caption from the image
    sampled_ids, attention, Beta = encoder2decoder.sampler(image)
    sampled_ids = Variable(torch.LongTensor(sampled_ids))
    sampled_ids = sampled_ids[0].cpu().numpy()  # (1, max_seq_length) -> (max_seq_length)

    # Convert word ids to words
    sampled_caption = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)

    # Print out the image and the generated caption
    print(sentence)
    image = Image.open(args.image)
    plt.imshow(np.asarray(image))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True,
                        help='input image for generating caption')
    parser.add_argument('--encoder2decoder_path', type=str, default='models/adaptive-1.pkl',
                        help='path for trained encoder2decoder')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl',
                        help='path for vocabulary wrapper')

    # Model parameters (should be the same as the parameters in train.py)
    parser.add_argument('--embed_size', type=int, default=256,
                        help='dimension of word embedding vectors')
    parser.add_argument('--hidden_size', type=int, default=512,
                        help='dimension of lstm hidden states')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in lstm')
    args = parser.parse_args()
    main(args)
```
`RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'`
please help!!!
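This error usually means the model weights and the index tensor handed to an embedding lookup live on different devices. The script above never moves `encoder2decoder` to the GPU, while the repo's `sampler` appears to build its start tokens with hard-coded `.cuda()` calls (the training code assumes a GPU), so a CPU embedding receives a CUDA index. A minimal sketch of a fix for the sampling step, assuming `sampler` accepts a batched image tensor as in the script above:

```python
# Put the model on the GPU so its weights live on the same device as
# the CUDA tensors that sampler creates internally.
encoder2decoder.cuda()
encoder2decoder.eval()

# The input image has to follow the model onto the GPU.
image = load_image(args.image, transform).cuda()

sampled_ids, attention, Beta = encoder2decoder.sampler(image)

# sampler already returns tensors, so drop the Variable/LongTensor
# re-wrapping and just bring the ids back to the CPU for decoding.
sampled_ids = sampled_ids[0].cpu().numpy()
```

For a CPU-only run the opposite change would be needed: any hard-coded `.cuda()` calls inside the repo's sampler would have to be removed as well.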
@inkzk
Did you later write the test module to generate image captions? Can you share it with me?
Did you solve it?
Could you please upload an example script?