moein-shariatnia / Deep-Learning

In-depth tutorials on deep learning. The first one is about image colorization using GANs (Generative Adversarial Nets).
MIT License
152 stars, 51 forks

How to test the code on custom data? #9

Open ParthKalkar opened 1 year ago

ParthKalkar commented 1 year ago

I went through your codebase and it explains the whole process excellently, but I could not find how to test the model on custom images after training.

I was thinking of extending this project and using this model in a web app. Therefore, I need to know: how do I test it on custom images?

moein-shariatnia commented 1 year ago

Hi Parth, I hope you're doing great!

Sorry for my late reply; I got busy with university for a while and missed your issue. Have you tried the infer.py file in the repo? It shows how to run the model on a single image.
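Running a trained colorization model on a custom image only works if the image goes through the same preprocessing as the training data. The following is a minimal sketch that assumes the tutorial's convention: images are converted to L*a*b*, the generator receives L rescaled as L/50 - 1, and it predicts a and b rescaled by 1/110. Verify both constants against infer.py and the dataset code before relying on them. The names `normalize_l` and `denormalize_lab` are hypothetical helpers, not functions from the repo.

```python
# Hedged sketch: ASSUMES the training pipeline scaled L*a*b* values as
# L_norm = L / 50 - 1 and ab_norm = ab / 110. Check infer.py and the
# dataset code before relying on these constants.
L_SCALE = 50.0    # assumed: L* in [0, 100] maps to [-1, 1]
AB_SCALE = 110.0  # assumed: a*, b* roughly in [-110, 110] map to [-1, 1]


def normalize_l(l_value: float) -> float:
    """Map an L* value from [0, 100] to the [-1, 1] range fed to the generator."""
    return l_value / L_SCALE - 1.0


def denormalize_lab(l_norm: float, a_norm: float, b_norm: float) -> tuple[float, float, float]:
    """Undo the scaling on the generator's input L and its predicted a, b."""
    return (l_norm + 1.0) * L_SCALE, a_norm * AB_SCALE, b_norm * AB_SCALE


if __name__ == "__main__":
    # A mid-grey pixel of a custom grayscale image has L* = 50.
    l_in = normalize_l(50.0)
    print(l_in)  # 0.0
    # Suppose the generator predicted (a, b) = (0.1, -0.2) in normalized units.
    print(denormalize_lab(l_in, 0.1, -0.2))
```

In a real pipeline you would apply the same scaling to the whole L channel of the resized image (at whatever size the model was trained on), run the generator, then pass the denormalized L, a, b arrays to a Lab-to-RGB conversion such as `skimage.color.lab2rgb`.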

raws84 commented 1 year ago

Hi Moein, it looks like the keys in the saved model don't match the model architecture. When I use the infer file, I get the following error. Any thoughts on how to fix this?

model initialized with norm initialization
model initialized with norm initialization

RuntimeError                              Traceback (most recent call last)
in <cell line: 11>()
      9 saved_model_dict = torch.load('final_model_weights.pt', map_location=device)
     10 model = MainModel()
---> 11 model.load_state_dict(saved_model_dict)
     12 
     13 # Load the black and white image and resize it

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MainModel:
    Missing key(s) in state_dict: "net_G.model.model.0.weight",
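When `load_state_dict` reports missing keys, it can help to diff the two key sets and group them by submodule. If the mismatch is confined to one prefix (such as `net_G`), that suggests the generator was constructed with a different architecture than the one the checkpoint was saved from, rather than the file being corrupt. The sketch below is stdlib-only and works on plain key names; with PyTorch you would pass `model.state_dict().keys()` and the keys of the dict returned by `torch.load`. Apart from `net_G.model.model.0.weight`, which appears in the error message, the example keys are hypothetical, as are the helper names.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key lists, mirroring what PyTorch
    reports when load_state_dict(strict=True) fails."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)     # model expects these; checkpoint lacks them
    unexpected = sorted(ckpt_set - model_set)  # checkpoint has these; model has no slot
    return missing, unexpected


def group_by_prefix(keys: Iterable[str], depth: int = 2) -> Dict[str, int]:
    """Count keys per leading dotted prefix (e.g. 'net_G.model') to see
    which submodule's layout differs."""
    counts: Dict[str, int] = defaultdict(int)
    for key in keys:
        counts[".".join(key.split(".")[:depth])] += 1
    return dict(counts)


if __name__ == "__main__":
    # Hypothetical key lists for illustration only.
    model_keys = ["net_G.model.model.0.weight", "net_D.model.0.weight"]
    ckpt_keys = ["net_G.0.0.weight", "net_D.model.0.weight"]
    missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
    print("missing:", group_by_prefix(missing))
    print("unexpected:", group_by_prefix(unexpected))
```

PyTorch can produce the same two lists directly: `model.load_state_dict(sd, strict=False)` returns an object with `missing_keys` and `unexpected_keys` (the `_IncompatibleKeys` visible in the traceback). Treat that as a diagnostic only, because any missing layers silently keep their freshly initialized weights.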

soumyamindfire commented 4 months ago


Load the ResNet weights as well. You have to train and save the weights.