amzn / convolutional-handwriting-gan

ScrabbleGAN: Semi-Supervised Varying Length Handwritten Text Generation (CVPR20)
https://www.amazon.science/publications/scrabblegan-semi-supervised-varying-length-handwritten-text-generation
MIT License

How to run inference for the generative model #7

Closed fabioperez closed 4 years ago

fabioperez commented 4 years ago

Hello!

I'm trying to use the pre-trained weights to generate handwritten text given arbitrary strings.

I managed to do that with something similar to the following code:

from PIL import Image
import torch
import numpy as np
from options.test_options import TestOptions
from models.BigGAN_networks import Generator
from util.util import prepare_z_y  # provides the noise sampler used below

def load_model():
    opt = TestOptions().parse()  # get test options
    opt.n_classes = 80
    gen = Generator(**vars(opt))

    state_dict = torch.load('./latest_net_G.pth')
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    gen.load_state_dict(state_dict)
    gen.eval()
    return gen

model = load_model()

char_to_int = {
    'a': 18,
    'b': 21,
    'c': 15,
    ...,
}

def get_word(word):
    encoded = [char_to_int[char] for char in word]
    words = torch.zeros((1, len(encoded), 80), dtype=torch.int32)
    for i, code in enumerate(encoded):
        words[0, i, code] = 1
    return words

def generate_image(word):
    seed = np.random.randint(0, 100000)
    word = get_word(word)
    z, _ = prepare_z_y(1, 128, 80, device='cpu', seed=seed)
    res = model.forward(z=z, y=word)  # fixed: was `words`, which is undefined here
    res = res.detach().numpy()[0, 0] * 255
    im = Image.fromarray(res).convert('RGB')
    return im

img = generate_image('testing')

I have two questions:

  1. I do not have the exact char_to_int encoding that maps characters to class labels. Where can I find it?
  2. Does that solution look OK for generating images?

Thanks!

sharonFogel commented 4 years ago

Hi, the code looks OK (I didn't try to run it). The encoding follows the order of the characters in data/alphabets: the first character in the string is encoded as 1, the second as 2, and so on. You can see it in the strLabelConverter class in models/OCR_network.py.
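
For later readers, a minimal sketch of that mapping, assuming data/alphabets exposes the English character set as alphabetEnglish (the 1-based offset follows the description above):

from data.alphabets import alphabetEnglish

# The first character maps to 1, the second to 2, and so on,
# per the reply above (0 is presumably reserved, e.g. for the CTC blank).
char_to_int = {char: i + 1 for i, char in enumerate(alphabetEnglish)}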

fabioperez commented 4 years ago

Thanks, data/alphabets is what I was looking for!

doulouUS commented 3 years ago

> (quoting @fabioperez's post and code above in full)

Hello,

First, thank you for this very interesting work!

I reused exactly the code above with the following changes:

from data.alphabets import alphabetEnglish

...

char_to_int = {
    char: i  # note: enumerate starts at 0, while the reply above describes a 1-based encoding
    for i, char in enumerate(alphabetEnglish)
}

...

def generate_image(word):

    ...

    z = torch.from_numpy(np.zeros((128, 80))).float()  # replaces prepare_z_y; this turns out to be the bug (see follow-up below)

    ...

However, running it leads to a matrix shape inconsistency:

Traceback (most recent call last):
  File "./generate.py", line 100, in <module>
    img = generate_image('testing')
  File "./generate.py", line 95, in generate_image
    res = model.forward(z=z, y=words)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/models/BigGAN_networks.py", line 385, in forward
    h = block(h, ys[index])
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/pytorch1.2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/models/BigGAN_layers.py", line 405, in forward
    h = self.activation(self.bn1(x, y))
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/pytorch1.2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/models/BigGAN_layers.py", line 313, in forward
    gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/pytorch1.2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/models/BigGAN_layers.py", line 123, in forward
    return F.linear(x, self.W_(), self.bias)
  File "/Users/A894EM/Documents/Computer-Vision/GANs/Handwriting_Generation/convolutional-handwriting-gan/pytorch1.2/lib/python3.7/site-packages/torch/nn/functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x16 and 32x256)

It seems to happen in the second generator block, at the first batch-norm layer. I did not modify the test options, apart from setting --gpu_ids to -1 in the base options, and I am using the pretrained model that you released. I suspect the issue lies in the configuration of the architecture's input/output channels, but I'm not certain...

Am I missing anything here?

Thanks!

doulouUS commented 3 years ago

I had actually missed the implementation of prepare_z_y in util/util.py; with it, everything works well. Thanks for the great work!
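
For anyone else who hit the same shape error after substituting a hand-built tensor for the noise vector: the fix is simply to sample z with the repo's helper, as in the first snippet in this thread. A minimal sketch, reusing the names defined above:

from util.util import prepare_z_y

# Sample z with the project's own helper (batch size 1, z dimension 128,
# 80 classes) instead of a raw zero tensor of the wrong shape.
z, _ = prepare_z_y(1, 128, 80, device='cpu', seed=seed)
res = model.forward(z=z, y=words)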

dagongji10 commented 3 years ago

@fabioperez I have modified your demo code like this:

from PIL import Image
import torch
import numpy as np
from options.test_options import TestOptions
from models.BigGAN_networks import Generator
from util.util import prepare_z_y

def load_model():
    opt = TestOptions().parse()  # get test options
    opt.n_classes = 80
    gen = Generator(**vars(opt))

    state_dict = torch.load('./pre-trained/latest_net_G.pth')
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    gen.load_state_dict(state_dict)
    gen.cuda()
    gen.eval()
    return gen

model = load_model()

alphabets_english = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%'
char_to_int = dict()
for ind, c in enumerate(alphabets_english):
    char_to_int[c] = ind

def get_word(word):
    encoded = [char_to_int[char] for char in word]
    words = torch.zeros((1, len(encoded), 80), dtype=torch.int32)
    for i, code in enumerate(encoded):
        words[0, i, code] = 1
    return words

def generate_image(word):
    seed = np.random.randint(0, 100000)
    words = get_word(word)
    z, _ = prepare_z_y(1, 128, 80, device=torch.device('cuda'), seed=seed)
    res = model.forward(z=z, y=words.to(torch.device('cuda'))).cpu()
    res = res.detach().numpy()[0, 0] * 255
    im = Image.fromarray(res).convert('RGB')
    return im

import time
for i in range(10):
    t0 = time.time()
    img = generate_image('abcdefg12345')
    img.save('/data/{}.png'.format(i))
    t1 = time.time()
    print(t1-t0)

But my results have some problems:

  1. The input word contains the digits 0-9, but the synthesized image shows the wrong characters, and some characters (such as "," and ".") do not appear in the image at all.
  2. The synthesized image is not clean; there are shadows near some letters.

[sample output image]

Have you met these problems? How can I solve them?

sharonFogel commented 3 years ago

Our method indeed has trouble generating characters that appear only a few times in the dataset (e.g. digits in the IAM dataset), so these results are not surprising. Using a dataset with more digits, at least for the recognizer training, might improve the results. The shadow is actually an artifact that can also be found in the original IAM dataset: the network learns to mimic the appearance of the data, and so it reproduces this character background as well.
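
If you want to gauge in advance which characters will be a problem, one quick check is to count per-character frequencies in the training transcriptions. A minimal sketch, assuming the labels are available as plain-text lines (the filename here is just a placeholder):

from collections import Counter

# Count how often each character occurs in the training transcriptions
# (train_transcriptions.txt is a hypothetical label file, one line per sample).
with open('train_transcriptions.txt') as f:
    counts = Counter(ch for line in f for ch in line.strip())

# Characters with very low counts (e.g. digits in IAM) are the ones
# the generator is most likely to render poorly.
for ch, n in sorted(counts.items(), key=lambda kv: kv[1]):
    print(repr(ch), n)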