cedrickchee / capsule-net-pytorch

[NO MAINTENANCE INTENDED] A PyTorch implementation of CapsNet architecture in the NIPS 2017 paper "Dynamic Routing Between Capsules".

RGB 256*256 image #4

Open deep0learning opened 6 years ago

deep0learning commented 6 years ago

Can you tell us how we can use your code to classify RGB images?

Our dataset is organized like this:

class1: 0001.jpg 0002.jpg
Class2: 001.jpg 002.jpg

cedrickchee commented 6 years ago

Sorry for the terribly late response. If I understand your intention correctly, you are trying to do inference using the model. Unfortunately, the code has no proper support for inference mode yet. It's actually not hard to make this work based on the existing codebase as we are using PyTorch.

Steps on how to achieve this:

  1. Create a new function in main.py and name it infer. Make it accept a new input parameter, data, for passing in images from your dataset (0001.jpg, 0002.jpg, etc.).
  2. Convert the pixel data from each image file into a multi-dimensional array.
  3. Write the code to load the trained or pre-trained weights. This is easy to do in PyTorch. Read the PyTorch docs here to find out how: http://pytorch.org/docs/master/notes/serialization.html#
  4. Copy the following code from the test function and modify it slightly so that it uses the passed-in image data instead of data from the dataset.
    # arr_image: batched image data as a tensor (convert an ndarray first, e.g. with torch.from_numpy)
    output = model(arr_image) # output predictions
  5. Extend the argparse.ArgumentParser setup by adding a new option that accepts image file paths passed in from your terminal/CLI.
  6. Profit!
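
The steps above could be sketched roughly as follows. This is only a sketch under stated assumptions, not code from the repository: load_image, load_weights, and build_parser are hypothetical helper names, the --images and --weights options are made up for illustration, Pillow is assumed for reading the image files, and the 28x28 grayscale input is an assumption based on MNIST-style training. For RGB 256*256 images the network would most likely have to be built and trained for that input shape first.

```python
import argparse

import numpy as np
import torch


def load_image(path, size=(28, 28)):
    """Step 2: read one image file into a float32 array of shape (1, H, W)."""
    # Assumption: Pillow is installed. It is imported here so the rest of
    # the sketch does not depend on it.
    from PIL import Image

    # Assumption: the trained model expects single-channel input. For RGB,
    # use .convert("RGB") and move the channel axis to the front instead.
    img = Image.open(path).convert("L").resize(size)
    arr = np.asarray(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    return arr[np.newaxis, :, :]  # add the channel dimension


def load_weights(model, weights_path):
    """Step 3: load a saved state_dict into an already constructed model."""
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


def infer(model, data):
    """Steps 1 and 4: forward pass over a list of (C, H, W) image arrays."""
    batch = torch.from_numpy(np.stack(data))  # shape (N, C, H, W)
    model.eval()  # switch layers such as dropout to inference behaviour
    with torch.no_grad():  # gradients are not needed for inference
        output = model(batch)  # output predictions
    return output


def build_parser():
    """Step 5: command-line options (hypothetical names) for inference."""
    parser = argparse.ArgumentParser(description="Run inference on image files")
    parser.add_argument("--weights", required=True,
                        help="path to the saved model weights")
    parser.add_argument("--images", nargs="+", required=True,
                        help="one or more image file paths")
    return parser
```

A call would then look something like infer(load_weights(model, args.weights), [load_image(p) for p in args.images]), where model is constructed the same way main.py builds it for training. What the model returns (class capsules only, or a reconstruction as well) depends on the network class in the repo, so the result may need unpacking.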

I am planning to add this to the codebase soon. Please read the README for the things I plan to do soon or in the near future.

Let us know if you have any queries.