boschresearch / ISSA

Official implementation of "Intra-Source Style Augmentation for Improved Domain Generalization" (WACV 2023 & IJCV)
GNU Affero General Public License v3.0

What is the unresolved reference 'torch_utils' in train_encoder.py and the structure of the dataset? #5

Closed: Liuhp133 closed this issue 11 months ago

YumengLi007 commented 1 year ago

Hi, you first need to train a StyleGAN generator, e.g., using the code here. You can follow its instructions to set up the datasets, torch_utils, etc. This repo only contains the training code for the GAN inversion encoder. Please also check how-to.pdf for more details.

LongVu219 commented 2 months ago

@YumengLi007 can you provide the structure of the data, or at least the data type? I am getting an error that it can't find train1119.npz. What is this npz file? Is it the w+ latent code generated using my own pretrained StyleGAN3 model?

YumengLi007 commented 2 months ago

Hi @LongVu219, below you can find (roughly) how the data is saved. Basically, we generated and saved some fake images together with their style latents:

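import os
import numpy as np
import PIL.Image
import torch

# Defined elsewhere: G (trained StyleGAN generator), device, seed, round, batch_size, output_dir.
lo, hi = -1, 1  # assumption: generator output range, as in the StyleGAN reference code

with torch.no_grad():  # inference only, so the outputs can be converted to NumPy
    z = torch.from_numpy(np.random.RandomState(seed + round).randn(batch_size, G.z_dim)).to(device)
    w_samples = G.mapping(z, None, truncation_psi=1.0) # 0.8
    imgs = G.synthesis(w_samples)
img = np.asarray(imgs.cpu(), dtype=np.float32)
img = (img - lo) * (255 / (hi - lo))  # rescale [lo, hi] to [0, 255]
img = np.rint(img).clip(0, 255).astype(np.uint8) # (bs,c,h,w)
img = img.transpose(0, 2, 3, 1) # (bs,h,w,c)

for i in range(batch_size):
    fname = f'{round}_{i}.png'
    PIL.Image.fromarray(img[i], 'RGB').save(os.path.join(output_dir, fname))
    # Save the first per-layer style vector w and the input noise z under the same file stem.
    np.savez(os.path.join(output_dir, f'{round}_{i}.npz'), w=w_samples[i,0,:].cpu().numpy(), z=z[i,:].cpu().numpy())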
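Reading the saved files back might look like the sketch below. It relies only on what the snippet above shows: each sample is a `.png` and a `.npz` sharing the same stem, and the `.npz` holds the keys `w` and `z`. The helper names `list_sample_stems` and `load_latents` are hypothetical (they are not part of this repo), and the latent dimension of 512 in the demo is an assumption used only to create dummy data.

```python
import os
import tempfile

import numpy as np


def list_sample_stems(data_dir):
    """Return the sorted file stems that have both a .png and a .npz file."""
    names = set(os.listdir(data_dir))
    stems = [n[:-4] for n in names if n.endswith(".npz") and n[:-4] + ".png" in names]
    return sorted(stems)


def load_latents(data_dir, stem):
    """Load the style latent w and the noise vector z saved for one sample."""
    with np.load(os.path.join(data_dir, stem + ".npz")) as data:
        return data["w"], data["z"]


if __name__ == "__main__":
    # Dummy data standing in for real generator output (dimension 512 is an assumption).
    with tempfile.TemporaryDirectory() as tmp:
        np.savez(os.path.join(tmp, "0_0.npz"), w=np.zeros(512, np.float32), z=np.zeros(512))
        open(os.path.join(tmp, "0_0.png"), "wb").close()  # empty placeholder for the image
        for stem in list_sample_stems(tmp):
            w, z = load_latents(tmp, stem)
            print(stem, w.shape, w.dtype, z.shape, z.dtype)
```

Note that `w` here is a single vector per image (the first of the per-layer style vectors), not the full per-layer tensor returned by `G.mapping`.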