Puzer / stylegan-encoder

StyleGAN Encoder - converts real images to latent space

Reduce encode_images.py time by using one model instance #54

Open rockdrigoma opened 3 years ago

rockdrigoma commented 3 years ago

Hi, I am trying to reduce generation time. So far it takes 2 min 20 s per image (generating 10 output images for age). I've noticed that encode_images.py spends this long on each input image:

  1. Initializing generator : 7.2106 secs
  2. Creating PerceptualModel : 9.0305 secs
  3. Loading Resnet model : 23.0473 secs
  4. Loop loss : 1.0582 secs
  5. Loop loss : 0.0619 secs
  6. Loop loss : 0.0630 secs
  7. Loop loss : 0.0618 secs
  8. Loop loss : 0.0628 secs
  9. Loop loss : 0.0621 secs
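Adding up the timings above makes the imbalance concrete. The three setup steps cost about 39 s per input image, while the six loss iterations together take under 1.4 s. This small Python check uses only the numbers reported here:

```python
# Timings (seconds) reported above for a single input image.
setup_steps = {
    "Initializing generator": 7.2106,
    "Creating PerceptualModel": 9.0305,
    "Loading Resnet model": 23.0473,
}
loss_iterations = [1.0582, 0.0619, 0.0630, 0.0618, 0.0628, 0.0621]

setup_total = sum(setup_steps.values())
loop_total = sum(loss_iterations)
setup_share = setup_total / (setup_total + loop_total)

print(f"setup: {setup_total:.4f} s, loop: {loop_total:.4f} s, "
      f"setup share: {setup_share:.1%}")
# prints: setup: 39.2884 s, loop: 1.3698 s, setup share: 96.6%
```

In other words, roughly 97% of the measured per-image time is one-off initialization that gets repeated for every image.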

So I am trying to initialize the generator, create the perceptual model and load the ResNet model once, at the beginning of my script, and pass them as parameters to encode_images.py so that steps 1 to 3 are not repeated for each image.
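The restructuring described above amounts to moving the expensive setup out of the per-image loop. Below is a minimal sketch of that pattern. It assumes hypothetical `load_models()` and `encode_one()` helpers that stand in for steps 1 to 3 and for the optimization loop. Neither helper exists in encode_images.py, and the real objects would be TensorFlow models, not strings.

```python
from dataclasses import dataclass

# Counts how many times the expensive setup runs (illustration only).
LOAD_CALLS = 0


@dataclass
class Models:
    """Hypothetical container for the objects built in steps 1 to 3."""
    generator: str
    perceptual_model: str
    resnet: str


def load_models() -> Models:
    """Hypothetical stand-in for the ~39 s of setup (generator,
    PerceptualModel, ResNet); here it only builds placeholder strings."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    return Models("generator", "perceptual_model", "resnet")


def encode_one(image_name: str, models: Models) -> str:
    """Hypothetical per-image step that reuses the already-built models."""
    return f"{image_name} encoded with {models.generator}"


def encode_all(image_names: list[str]) -> list[str]:
    models = load_models()  # expensive setup runs once, not once per image
    return [encode_one(name, models) for name in image_names]


results = encode_all(["a.png", "b.png", "c.png"])
print(LOAD_CALLS)  # prints 1
```

With TensorFlow 1.x models, the objects created once must also live in the same `tf.Graph`. That is the constraint the ValueError below refers to.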

But I have no idea whether that's the right way to do it. Instead of calling the script directly, I defined an auxiliar() function that takes the same flags and parameters:

Newly defined function:

auxiliar(optimizer='lbfgs', face_mask=True, iterations=6, use_lpips_loss=0, use_discriminator_loss=0, output_video=False, src_dir='aligned_images/', generated_images_dir='generated_images/', dlatent_dir='latent_representations/')

Previous script call:

python encode_images.py --optimizer=lbfgs --face_mask=True --iterations=6 --use_lpips_loss=0 --use_discriminator_loss=0 --output_video=False aligned_images/ generated_images/ latent_representations/
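`PerceptualModel` receives the whole `args` namespace (see the snippet at the end of this post), so `auxiliar()` has to build the same namespace that the command line would. One way to guarantee this is to convert the keyword arguments back into an argv list and pass it to the same `argparse` parser. This is a sketch under that assumption. `build_parser()`, `str2bool()` and `auxiliar_args()` are hypothetical simplified stand-ins that cover only the flags in the call above. The real encode_images.py defines more options, with its own defaults and boolean parsing.

```python
import argparse

POSITIONAL = ("src_dir", "generated_images_dir", "dlatent_dir")


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (hypothetical helper)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Simplified, hypothetical stand-in for the encode_images.py parser."""
    parser = argparse.ArgumentParser()
    for name in POSITIONAL:
        parser.add_argument(name)
    parser.add_argument("--optimizer", type=str, default=None)
    parser.add_argument("--face_mask", type=str2bool, default=None)
    parser.add_argument("--iterations", type=int, default=None)
    parser.add_argument("--use_lpips_loss", type=float, default=None)
    parser.add_argument("--use_discriminator_loss", type=float, default=None)
    parser.add_argument("--output_video", type=str2bool, default=None)
    return parser


def auxiliar_args(**kwargs) -> argparse.Namespace:
    """Rebuild argv from keyword arguments and parse it like the CLI would."""
    argv = [str(kwargs.pop(name)) for name in POSITIONAL]
    argv += [f"--{key}={value}" for key, value in kwargs.items()]
    return build_parser().parse_args(argv)


args = auxiliar_args(optimizer="lbfgs", face_mask=True, iterations=6,
                     use_lpips_loss=0, use_discriminator_loss=0,
                     output_video=False, src_dir="aligned_images/",
                     generated_images_dir="generated_images/",
                     dlatent_dir="latent_representations/")
print(args.iterations, args.face_mask)  # prints 6 True
```

Going through the parser means a misspelled keyword fails the same way a misspelled flag would (argparse reports "unrecognized arguments" and exits).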

So far I am getting this error: ValueError: Tensor("Const_1:0", shape=(3,), dtype=float32) must be from the same graph as Tensor("strided_slice:0", shape=(1, 256, 256, 3), dtype=float32).

It is raised at this point in the code that used to be in encode_images.py:

perceptual_model = PerceptualModel(args, perc_model=perc_model, batch_size=batch_size)
perceptual_model.build_perceptual_model(generator, discriminator_network)