junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

Not able to load a pre trained weight that I have used for training of custom data. #1621

Open AGRocky opened 11 months ago

AGRocky commented 11 months ago

Hi everyone, please hear me out. I have trained a model and have the trained weights, but I am not able to load the model to test it on a different set of images that I have prepared. Please help me out with this one. It would be deeply appreciated.

Thanks in advance.

JustinasLekavicius commented 10 months ago

Was it a CycleGAN or a pix2pix model?

To test it on a different data set, you could use this command as an example:

python test.py --dataroot directory, e.g. (/content/data/A) --name model_name --model test --netD n_layers --n_layers_D 3 --netG=unet_256 --norm=instance --direction AtoB --dataset_mode single --preprocess none --input_nc 1 --output_nc 3 --ndf 64 --ngf 64 --num_test 512

Make sure to replace the parameters with the same ones you used to train the model (netD, n_layers, netG, ndf, ngf, etc.).
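To make the "same parameters" advice concrete, here is a minimal sketch that compares two sets of architecture options and reports the ones that differ. The helper name `diff_options` and the example values are illustrative assumptions, not part of the repository; substitute the options you actually trained with.

```python
def diff_options(train_opts, test_opts, keys):
    """Return {key: (train_value, test_value)} for every key whose values differ."""
    return {
        k: (train_opts.get(k), test_opts.get(k))
        for k in keys
        if train_opts.get(k) != test_opts.get(k)
    }


# Options that change the generator architecture or its input/output shape.
ARCH_KEYS = ["netG", "norm", "ngf", "input_nc", "output_nc"]

# Hypothetical values, for illustration only.
train_opts = {"netG": "resnet_9blocks", "norm": "instance", "ngf": 64,
              "input_nc": 3, "output_nc": 3}
test_opts = {"netG": "unet_256", "norm": "instance", "ngf": 64,
             "input_nc": 1, "output_nc": 3}

mismatches = diff_options(train_opts, test_opts, ARCH_KEYS)
for key, (train_value, test_value) in mismatches.items():
    print(f"{key}: trained with {train_value!r}, testing with {test_value!r}")
```

Any key reported here is a candidate explanation for weights that fail to load or that load but produce poor output.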

AGRocky commented 10 months ago

Hey @JustinasLekavicius, thank you for replying to my issue. I am using CycleGAN, and the purpose of the training is to denoise images. The trained weights work well when used from the command line with "python test.py --dataroot directory, e.g. (/content/data/A) --name model_name --model test..."

But when I try to do the same thing by writing a class that loads the model, takes an image as input, and returns the denoised image, it does not work. The model loads, but the denoised output image is not generated accurately, whereas the command above produces it correctly. Please help me with this.

Thank you very much in advance.

AGRocky commented 10 months ago

import torch
from models.networks import define_G
from PIL import Image
from torchvision import transforms
from IPython.display import display

# Define the generator model
generator = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                     norm='instance', use_dropout=False,
                     init_type='normal', init_gain=0.02)

# Load the pre-trained weights from a saved checkpoint
generator_checkpoint_path = 'latest_net_G.pth'
checkpoint = torch.load(
    generator_checkpoint_path,
    map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Load the generator state_dict
generator.load_state_dict(checkpoint, strict=False)

# Set the model to evaluation mode (important if using dropout during training)
generator.eval()

# Load your input image
input_image_path = 'nt6.jpg'
input_image = Image.open(input_image_path).convert('RGB')

# Resize the input image to the expected size
input_image = input_image.resize((512, 512))

# Convert the input image to a PyTorch tensor
input_tensor = transforms.ToTensor()(input_image).unsqueeze(0)  # Add batch dimension

# Move the input tensor to the GPU if available
if torch.cuda.is_available():
    input_tensor = input_tensor.to('cuda')

# Move the generator to the same device as the input tensor
generator = generator.to(input_tensor.device)

# Generate the output image
with torch.no_grad():
    output_tensor = generator(input_tensor)

# Move the output tensor to the CPU if necessary
output_tensor = output_tensor.cpu()

# Convert the output tensor to a PIL image
output_image = transforms.ToPILImage()(output_tensor.squeeze(0))

# Display the generated image
display(output_image)

# Save the generated image
output_image.save('generated_image.jpg')

This is my code.
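One likely difference between this script and test.py is the value range. As far as I can tell, the repository's data pipeline normalizes inputs to [-1, 1] after ToTensor(), and the generators end in a tanh, so their output is also in [-1, 1]. The script above feeds [0, 1] values in and passes the raw output straight to ToPILImage(), which could explain dark or inaccurate results. The sketch below shows the two mappings with plain arithmetic; the helper names are my own, not part of the repository, and the first two expressions also work elementwise on tensors.

```python
def to_model_range(x):
    """Map a value from [0, 1] (what ToTensor produces) to [-1, 1]."""
    return (x - 0.5) / 0.5


def to_image_range(y):
    """Map a value from [-1, 1] (what a tanh output produces) back to [0, 1]."""
    return (y + 1.0) / 2.0


def to_uint8(v):
    """Clamp a single float to [0, 1] and scale it to an 8-bit pixel value."""
    return int(round(min(max(v, 0.0), 1.0) * 255))


# Input side: [0, 1] -> [-1, 1]
print(to_model_range(0.0), to_model_range(1.0))  # -1.0 1.0

# Output side: [-1, 1] -> [0, 1] -> 0..255
print(to_uint8(to_image_range(-1.0)))  # 0
print(to_uint8(to_image_range(1.0)))   # 255
```

In the script, this would mean adding transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) after ToTensor(), and applying (output_tensor + 1) / 2 followed by a clamp to [0, 1] before ToPILImage(). Treat this as a guess about the cause, not a confirmed fix.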

ystoneman commented 9 months ago

Hi AGRocky,

To help troubleshoot your CycleGAN model issue, could you provide:

  1. Versions: Exact versions of Python, PyTorch, and other libraries used.
  2. System Specs: Your GPU model and overall system configuration.
  3. Error Details: Any specific error messages or warnings during model loading.

These details will help in accurately replicating the issue and providing a solution.

Thanks!
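On the "error details" point: loading with strict=False suppresses the error PyTorch would otherwise raise when the checkpoint keys do not match the model, so a mismatch can go unnoticed and leave the generator with its random initial weights. A minimal sketch of the comparison, using plain key lists (the function name and the example keys are hypothetical):

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring what a strict load checks."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # in the model, not in the file
    unexpected = sorted(checkpoint_keys - model_keys)   # in the file, not in the model
    return missing, unexpected


# Hypothetical keys: a checkpoint saved with a "module." prefix will not match.
model_keys = ["model.1.weight", "model.1.bias"]
checkpoint_keys = ["module.model.1.weight", "module.model.1.bias"]

missing, unexpected = compare_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the same information is available directly: load_state_dict(checkpoint, strict=False) returns an object with missing_keys and unexpected_keys attributes, and printing it (or loading with strict=True) shows whether the weights were actually applied.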

zhaoyong-li commented 2 months ago

Hi, have you solved the problem? I also loaded the model this way, but my output is a black image, and I don't know why.