anderzzz / monkey_caput

Custom PyTorch model (VGG-16 Auto-Encoder) and custom criterion (Local Aggregation) for image clustering. The repo contains elaborated creation of fungi image data using factory method.

Where is the starting point for custom images? #5

Open Weilin37 opened 2 years ago

Weilin37 commented 2 years ago

Hi,

I'd like to use this for unsupervised image clustering. Is it as simple as just swapping the example images with my own images, keeping the rest of the code the same? If so, how do I run this?

anderzzz commented 2 years ago

Generally, the files with names ending in _runs.py are the ones I execute; see for example ae_runs.py. However, you should rewrite these for your own setup, or implement the same steps in a Jupyter Notebook.

The typical training loop is found in the files with names ending in _learner.py; see for example ae_learner.py. If you use your own dataset of images, you have to create your own PyTorch Dataset. You can see on this line in the parent learner class where I create the fungi image dataset. That is where you have to insert your own code to create a PyTorch Dataset. Exactly how you create the Dataset is mostly up to you, for example whether you include data augmentation, cropping, etc. But try to use images of 224x224 pixels, otherwise the layer dimensions in the convolutional network may not match up. The Dataset generator class I used the most was this one.
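To illustrate why 224x224 is a safe choice, here is a small arithmetic sketch. It assumes the standard VGG-16 layout (five 2x2 max-pool layers with stride 2, and size-preserving 3x3 convolutions) and a decoder that doubles the side length once per pooling step; the exact layers in this repo may differ, and the function names are made up for this example.

```python
def encoder_sizes(side, n_pools=5):
    """Side length after each 2x2, stride-2 max-pool (floor division, as in PyTorch's default)."""
    sizes = [side]
    for _ in range(n_pools):
        side = side // 2
        sizes.append(side)
    return sizes


def round_trips(side, n_pools=5):
    """True if doubling the encoded side n_pools times restores the input side."""
    encoded = encoder_sizes(side, n_pools)[-1]
    return encoded * 2 ** n_pools == side


print(encoder_sizes(224))  # [224, 112, 56, 28, 14, 7]
print(encoder_sizes(250))  # [250, 125, 62, 31, 15, 7]
print(round_trips(224))    # True
print(round_trips(250))    # False: the decoder would produce 224, not 250
```

Under these assumptions, any side length divisible by 32 survives the encode/decode round trip, and 224 is the size VGG-16 was originally designed for.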

The high-level steps for the unsupervised learning are: (1) Train the Auto-Encoder on your dataset. I recommend that you start with a small dataset. It can be tricky to get an Auto-Encoder this deep to converge, so you have to explore your training parameters. It also depends on how complex your images are. This is the step I implemented in ae_runs.py. (2) Once you have a decent model for the Auto-Encoder, you initialize LALearner. Just as ae_runs.py does this for the Auto-Encoder, la_runs.py shows how it is done for LALearner. Note that LALearner must load the model you obtained from training the Auto-Encoder, see for example this line. Then you run the training. LALearner inherits from _Learner, as AELearner does, and _Learner is where you define the PyTorch Dataset, as described above. And again, the LA method has a number of hyper-parameters that take some work to tune to your specific problem.

As I wrote in my blog post on this, unsupervised learning is not nearly as developed as the supervised learning methods. That means the hyper-parameters are not nearly as well established. That makes the work interesting, but also a lot harder. I say this because I know it can be hard to make it work for your own dataset. You are free to use my code and explore, but it is wise to have some prior familiarity with PyTorch and how to train deep networks. Convergence can be hard to reach.

I hope this helps and good luck.

Ryanunana commented 3 weeks ago

Thank you for your great code. I'm new to deep learning. Could you please show me how to create my own dataset? I have tried many times but still can't get it to work. I'd appreciate any advice you can give.

anderzzz commented 2 weeks ago

Hello. Sorry for the late reply.

The dataset is a PyTorch Dataset. https://pytorch.org/docs/stable/data.html

The image dataset enters the code here: https://github.com/anderzzz/monkey_caput/blob/d4789446636953377d1c709207ce15421897eeb9/_learner.py#L104

The custom datasets I create are subclasses of the PyTorch Dataset. The central features of a custom dataset class are a “length” method (__len__) and a “get item” method (__getitem__). Those are the ones you must write to read from your image collection.
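As a rough illustration of those two methods, here is a minimal sketch of a custom dataset over a list of image files. The class name MyImageDataset, the load_rgb helper and the constructor arguments are all invented for this example and are not taken from the repo; the default loader assumes Pillow is installed.

```python
from pathlib import Path

try:
    from torch.utils.data import Dataset
except ImportError:
    # Fallback so the sketch can be read and run without PyTorch installed.
    Dataset = object


def load_rgb(path):
    """Default loader (hypothetical helper): open an image file as RGB. Requires Pillow."""
    from PIL import Image
    with Image.open(path) as img:
        return img.convert("RGB")


class MyImageDataset(Dataset):
    """Hypothetical dataset over a flat list of image file paths."""

    def __init__(self, image_paths, loader=load_rgb, transform=None):
        self.image_paths = [Path(p) for p in image_paths]
        self.loader = loader
        # transform is any callable, e.g. one that resizes to 224x224 and converts to a tensor
        self.transform = transform

    def __len__(self):
        # The "length": how many images the dataset contains.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The "get item": return the image stored at position idx.
        image = self.loader(self.image_paths[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image
```

For unsupervised training only the image is returned; a supervised variant would return an (image, label) pair instead.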

There is a very good tutorial for that, see https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

You can see the core part of my dataset here: https://github.com/anderzzz/monkey_caput/blob/master/fungidata.py#L268

The reason I have so much more in the Dataset than is minimally required is that during training it is often helpful to augment the dataset with transformed images. But if you have trouble getting started, you should follow the tutorial above. Once done, you can consider how to integrate augmented datasets.
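One simple way to think about augmentation, sketched below, is a wrapper that presents every (image, transform) pair as its own item. This is only an illustration of the idea, not how fungidata.py does it; AugmentedView and its arguments are hypothetical names.

```python
class AugmentedView:
    """Hypothetical wrapper: each base item appears once per transform."""

    def __init__(self, base, transforms):
        self.base = base                    # anything with __len__ and __getitem__
        self.transforms = list(transforms)  # callables applied to a base item

    def __len__(self):
        return len(self.base) * len(self.transforms)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Map the flat index to (which base item, which transform).
        base_idx, transform_idx = divmod(idx, len(self.transforms))
        return self.transforms[transform_idx](self.base[base_idx])
```

With two transforms, for example the identity and a horizontal flip, a base dataset of N images yields 2N items, ordered so that all variants of one image are adjacent.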

I hope this helps. The PyTorch Dataset class lets PyTorch read your data no matter how you store and organize it. Beyond that, it is simply a way of returning an image given an index, plus a label if you do supervised learning.

If you still cannot make it work, you have to tell me more about how you store the image files. What I did was create a CSV file that serves as a table of contents for my image collection, see this line: https://github.com/anderzzz/monkey_caput/blob/master/fungidata.py#L286
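For readers who want to try the CSV approach, here is a minimal sketch that builds such a table of contents from a folder of images using only the standard library. The column names ("index", "filename"), the accepted file suffixes and the function names are assumptions made for this example; the CSV layout actually expected by fungidata.py may be different.

```python
import csv
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed set of image types


def write_image_index(image_dir, csv_path):
    """Write one CSV row per image found under image_dir; return the row count."""
    image_dir = Path(image_dir)
    files = sorted(
        p for p in image_dir.rglob("*") if p.suffix.lower() in IMAGE_SUFFIXES
    )
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["index", "filename"])
        for i, path in enumerate(files):
            # Store paths relative to image_dir so the collection can be moved.
            writer.writerow([i, path.relative_to(image_dir).as_posix()])
    return len(files)


def read_image_index(csv_path):
    """Return the list of relative file names recorded in the CSV."""
    with open(csv_path, newline="") as f:
        return [row["filename"] for row in csv.DictReader(f)]
```

A dataset class can then read this file once in its constructor and use the row at position idx to locate the image to load.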