alexanderkroner / saliency

Contextual Encoder-Decoder Network for Visual Saliency Prediction [Neural Networks 2020]
MIT License

Training on custom (new) dataset. #29

Open rakehsaleem opened 11 months ago

rakehsaleem commented 11 months ago

Hi @alexanderkroner, thank you for such a nice piece of code and thorough documentation for the repository.

How can I train a model on a custom/new dataset? My dataset contains eye-tracking videos of 10 participants, each 3 to 4 minutes long. I have the gaze information, fixations, and timestamps. From what I understand, I first have to prepare the dataset in the same format as the SALICON dataset so that the code accepts and reads it, and then train the model. However, I have some confusion and would like to ask a few questions to point me in the right direction.

  1. Dataset preparation

    • convert the videos into frames such that each frame contains one participant's gaze data, timestamp, and fixation information.
    • generate a saliency map for each frame and feed it as the ground truth.
  2. Model training

    • first pre-train the model on the SALICON dataset.
    • then fine-tune the pre-trained model on my dataset.

    To fine-tune the model, what parameters must I change, and how do I reference the SALICON pre-trained model in the command so that it is loaded first and then fine-tuned on my dataset?

    Thank you so much for your time.

alexanderkroner commented 8 months ago

Hi @rakehsaleem, apologies for the late reply!

Yes, you can easily train a model on a custom dataset. Here is how:

  1. Your videos need to be saved as individual frames together with the corresponding fixation data, using the same file names. The fixations should come in the form of saliency maps, i.e., after applying a (Gaussian) blur to the fixation locations (see the sketch after these steps for one way to generate them). Stimuli and saliency maps must then be moved to the folders ./data/custom/stimuli/ and ./data/custom/saliency/ respectively.
  2. Make sure to download the model with weights pre-trained on SALICON before fine-tuning it on your custom dataset. You can trigger the automatic download by running the command python main.py test -d salicon or download the model here and save the file under ./weights/.
  3. Add the entry "image_size_custom" with your preferred image size (height by width, both divisible by 8) here: https://github.com/alexanderkroner/saliency/blob/c64f87b90ecdf237e93b12a145a22d0903b07d48/config.py#L23-L31
  4. Add the entry "custom" here: https://github.com/alexanderkroner/saliency/blob/c64f87b90ecdf237e93b12a145a22d0903b07d48/main.py#L220-L221
  5. Add the entry "custom" here: https://github.com/alexanderkroner/saliency/blob/c64f87b90ecdf237e93b12a145a22d0903b07d48/model.py#L405-L406
  6. Add the class CUSTOM to data.py and adjust the number of training/validation examples:

    class CUSTOM:
        # Custom dataset following the SALICON-style folder layout.
        n_train = 100  # adjust to your number of training examples
        n_valid = 50   # adjust to your number of validation examples
    
        def __init__(self, data_path):
            self._target_size = config.DIMS["image_size_custom"]
    
            self._dir_stimuli = data_path + "stimuli"
            self._dir_saliency = data_path + "saliency"
    
        def load_data(self):
            # Stimuli and saliency maps are paired via their file names.
            list_x = _get_file_list(self._dir_stimuli)
            list_y = _get_file_list(self._dir_saliency)
    
            _check_consistency(zip(list_x, list_y), self.n_train + self.n_valid)
    
            # Shuffle all examples once, then split off the training set.
            indices = _get_random_indices(self.n_train + self.n_valid)
            excerpt = indices[:self.n_train]
    
            train_list_x = [list_x[idx] for idx in excerpt]
            train_list_y = [list_y[idx] for idx in excerpt]
    
            # The final argument marks the training split (True) versus
            # the validation split (False).
            train_set = _fetch_dataset((train_list_x, train_list_y),
                                       self._target_size, True)
    
            # The remaining examples form the validation set.
            excerpt = indices[self.n_train:]
    
            valid_list_x = [list_x[idx] for idx in excerpt]
            valid_list_y = [list_y[idx] for idx in excerpt]
    
            valid_set = _fetch_dataset((valid_list_x, valid_list_y),
                                       self._target_size, False)
    
            return (train_set, valid_set)
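
Regarding the saliency maps in step 1, here is a minimal sketch of how they could be generated from raw fixation coordinates. The file names, frame size, and sigma value are placeholders, not part of this repository; sigma is typically chosen to approximate one degree of visual angle for your recording setup.

    import os

    import numpy as np
    from PIL import Image
    from scipy.ndimage import gaussian_filter

    def fixations_to_saliency_map(fixations, height, width, sigma=25.0):
        # Accumulate fixation locations (x, y in pixels) on an empty canvas.
        fixation_map = np.zeros((height, width), dtype=np.float64)
        for x, y in fixations:
            if 0 <= int(y) < height and 0 <= int(x) < width:
                fixation_map[int(y), int(x)] += 1.0

        # Blur the fixation map to obtain a continuous saliency map.
        saliency = gaussian_filter(fixation_map, sigma=sigma)

        # Normalize to [0, 255] and convert to an 8-bit grayscale image.
        if saliency.max() > 0:
            saliency /= saliency.max()
        return (saliency * 255).astype(np.uint8)

    # Hypothetical example: two fixations on a 640x480 frame. The map is
    # saved under the same file name as the corresponding stimulus frame.
    os.makedirs("data/custom/saliency", exist_ok=True)
    saliency = fixations_to_saliency_map([(320, 240), (100, 80)], 480, 640)
    Image.fromarray(saliency).save("data/custom/saliency/frame_0001.png")

Note that each saliency map must share its file name with the matching frame in ./data/custom/stimuli/ so the two folders stay aligned.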

Finally, you should be able to train a model on your custom dataset, starting from the SALICON weights, by running the command python main.py train -d custom.
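
For reference, the change from step 3 amounts to one new dictionary entry. This is only a sketch: the surrounding keys are abbreviated and the example size is an assumption; pick a size that matches your frames' aspect ratio, with height and width both divisible by 8.

    # config.py (sketch): only the "image_size_custom" line is new.
    DIMS = {
        "image_size_salicon": (240, 320),
        # ...existing entries...
        "image_size_custom": (240, 320),  # example size, height by width
    }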

rakehsaleem commented 8 months ago

Thank you, @alexanderkroner, for sharing such a detailed response. I will try this out and see how the training goes! I will post my results here soon.