rdroste / unisal

Unified Image and Video Saliency Modeling (ECCV 2020)
https://arxiv.org/abs/2003.05477
Apache License 2.0

finetuning on a different dataset #4

Closed prashnani closed 1 year ago

prashnani commented 4 years ago

Hi! thanks for releasing the code for unisal!

Is it possible to provide a minimal set of scripts and instructions to finetune the unisal model on a different dataset other than the ones listed in the respository? I notice that data.py, train.py etc. come with various commands to switch across the datasets you used to train unisal. Would be great to get some help with training on just one new dataset.

Thanks, Ekta

rdroste commented 4 years ago

Hi Ekta, thanks for your interest in our work. A minimal example for fine-tuning the model is a good idea, I'll try to find some time soon to upload one.

However, one difficulty with a general fine-tuning example is that the optimal fine-tuning method (learning rate, learning rate schedule, batch size, freezing different parts of the network, etc.) really depends on the target dataset. In the meantime, you could manually load the UNISAL model and plug it into your own training script.

To load the pretrained model you can run something like:

import unisal
model = unisal.model.UNISAL()
model.load_best_weights('unisal/training_runs/pretrained_unisal')

If you want to load the model for one of the training datasets only, you could also run (untested):

import torch
my_source = <insert whichever dataset matches your data most closely from ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')>
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)

(Instead of using strict=False, which can fail silently, you could also remove the weights with keys 'rnn', 'post_rnn' and keys containing 'DHF1K', 'Hollywood' or 'UCFSports' from the state dict)
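The key filtering described above could be sketched like this. The prefix and substring names follow the comment (they are assumptions, so verify them against model.state_dict().keys() before relying on this):

```python
# Sketch: drop the RNN weights and the weights tied to the video domains
# before loading, instead of relying on strict=False.
# Key names ('rnn', 'post_rnn', and the dataset names) are taken from the
# comment above; check them against model.state_dict().keys().

def filter_state_dict(state_dict,
                      drop_prefixes=('rnn', 'post_rnn'),
                      drop_substrings=('DHF1K', 'Hollywood', 'UCFSports')):
    """Return a copy of state_dict without RNN and video-domain weights."""
    return {
        key: value for key, value in state_dict.items()
        if not key.startswith(drop_prefixes)
        and not any(s in key for s in drop_substrings)
    }

# Usage (untested):
# state_dict = torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth')
# model.load_state_dict(filter_state_dict(state_dict), strict=False)
```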

If you want to use the model for static data only, you can reduce the model size by loading it without the GRU RNN by running (untested):

import torch
my_source = 'SALICON'
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)

In your training code, you can then call the model with:

# ... your code here
my_source = <one of ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')>
static = <True or False>
prediction = model(training_batch, source=my_source, static=static)

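To make the wiring concrete, here is a minimal fine-tuning loop sketch. The TinyStandIn module is a hypothetical stand-in that only mimics the forward signature model(x, source=..., static=...); in practice you would substitute the pretrained UNISAL model loaded as above. The loss, optimizer, learning rate, and tensor shapes are illustrative assumptions, not the paper's settings (the real model may also expect a time dimension even for static data, so check the repository's data loaders):

```python
import torch
import torch.nn as nn

class TinyStandIn(nn.Module):
    """Hypothetical stand-in mimicking UNISAL's forward signature.
    Replace with the pretrained UNISAL model loaded as shown above."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x, source='SALICON', static=True):
        # x: (batch, channels, height, width) here; the real model's
        # expected shape may differ (e.g. an extra time dimension).
        return self.conv(x)

model = TinyStandIn()  # in practice: the loaded UNISAL model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)  # lr is a guess
criterion = nn.MSELoss()  # UNISAL itself uses saliency-specific losses

# Dummy batch standing in for your dataset's images and saliency maps
images = torch.randn(2, 3, 64, 64)
targets = torch.rand(2, 1, 64, 64)

model.train()
optimizer.zero_grad()
prediction = model(images, source='SALICON', static=True)
loss = criterion(prediction, targets)
loss.backward()
optimizer.step()
```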
prashnani commented 4 years ago

thanks for your response @rdroste! Will wait for your minimal example. :+1: Makes sense that with a new dataset, there would be work involved in hyper-parameter tuning.

For plugging unisal into my own training script: it would be great to know which components of model.py are needed when training on just one new dataset (one not among the datasets used in your method). As of now, model.py seems to contain domain-specific normalization, handling of multiple sources, etc. Are these components needed when there is only one (new) dataset given for training?

rdroste commented 4 years ago

Exactly, you wouldn't need to keep all the domain-specific parameters when you fine-tune the model. When initializing the UNISAL model class, I would set sources=(my_source,), where my_source is the dataset from ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON') that matches your target dataset most closely. It might be worth trying out all of them. For a static target dataset the best choice would be sources=('SALICON',). For video data it's difficult to say a priori, but sources=('DHF1K',) might be a good default option since DHF1K is the most varied of the video datasets.

Afterwards, call the model forward with pred = model(x, source=my_source). Hope that makes sense.

prashnani commented 4 years ago

thanks @rdroste! let me give this a try.