Hi Ekta, thanks for your interest in our work. A minimal example for fine-tuning the model is a good idea; I'll try to find some time soon to upload one.
However, one difficulty with a general fine-tuning example is that the optimal fine-tuning setup (learning rate, learning rate schedule, batch size, freezing different parts of the network, etc.) really depends on the target dataset. Therefore, you could manually load the UNISAL model and plug it into your own training script.
To load the pretrained model you can run something like:
import unisal
model = unisal.model.UNISAL()
model.load_best_weights('unisal/training_runs/pretrained_unisal')
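As a quick sanity check, something like this should then produce a prediction (untested; the 5-D input shape (batch, time, channels, height, width) and the spatial size are assumptions here, check data.py for the shapes the data loaders actually produce):
import torch

model.eval()
# Dummy input: a batch of 2 clips with 6 frames each at 288x384.
x = torch.randn(2, 6, 3, 288, 384)
with torch.no_grad():
    pred = model(x, source='DHF1K', static=False)
print(pred.shape)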
If you want to load the model for one of the training datasets only, you could also run (untested):
import torch
my_source = <insert whichever dataset matches your data most closely from ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')>
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)
(Instead of using strict=False, which can fail silently, you could also remove the weights with keys 'rnn' and 'post_rnn', and keys containing 'DHF1K', 'Hollywood' or 'UCFSports', from the state dict.)
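For example, a minimal sketch of that filtering for the static SALICON-only case (the key names are the ones suggested above; print state_dict.keys() to verify them against your checkpoint, and adapt drop_terms if you keep the RNN for a video source):
import torch
import unisal

model = unisal.model.UNISAL(sources=('SALICON',))
state_dict = torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth')
# Drop the RNN weights and the domain-specific weights of the other sources.
drop_terms = ('rnn', 'post_rnn', 'DHF1K', 'Hollywood', 'UCFSports')
filtered = {k: v for k, v in state_dict.items()
            if not any(term in k for term in drop_terms)}
# load_state_dict reports what was skipped, so nothing fails silently.
missing, unexpected = model.load_state_dict(filtered, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)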
If you want to use the model for static data only, you can reduce the model size by loading it without the GRU RNN (untested):
my_source = 'SALICON'
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'), strict=False)
In your training code you can then call the model with:
# ... your code here
my_source = <one of ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON')>
static = <True or False>
prediction = model(training_batch, source=my_source, static=static)
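As a rough starting point, a fine-tuning loop could then look like this (untested; my_loader, the Adam settings and the MSE loss below are placeholders for your own data loader, hyperparameters and loss, see train.py for the losses we actually use):
import torch
import unisal

my_source = 'DHF1K'  # pick the source that matches your data best
model = unisal.model.UNISAL(sources=(my_source,))
model.load_state_dict(
    torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'),
    strict=False)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for clip, target in my_loader:  # my_loader is a placeholder DataLoader
    optimizer.zero_grad()
    pred = model(clip, source=my_source, static=False)
    loss = torch.nn.functional.mse_loss(pred, target)  # placeholder loss
    loss.backward()
    optimizer.step()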
Thanks for your response @rdroste! Will wait for your minimal example. :+1: Makes sense that with a new dataset there would be work involved in hyperparameter tuning.
For plugging UNISAL into my own training script: it would be great to know which components of model.py are needed when training on just one new dataset (one not in the list of datasets in your method). As of now it seems that model.py contains domain-specific normalization, multiple sources, etc.; are these components still needed when there is only a single new dataset for training?
Exactly, you wouldn't need to keep all the domain-specific parameters when you fine-tune the model. When initializing the UNISAL model class I would set sources=(my_source,), where my_source is the dataset from ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON') that matches your target dataset most closely. It might be worth trying out all of them. For a static target dataset the best choice would be sources=('SALICON',). For video data it's difficult to say a priori, but sources=('DHF1K',) might be a good default option since DHF1K is the most varied of the video datasets. Afterwards, call the model forward with pred = model(x, source=my_source). Hope that makes sense.
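Since it might be worth trying all four sources, a quick comparison could look like this (untested; evaluate_on_validation is a placeholder for your own validation metric):
import torch
import unisal

results = {}
for src in ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'):
    model = unisal.model.UNISAL(sources=(src,))
    model.load_state_dict(
        torch.load('unisal/training_runs/pretrained_unisal/weights_best.pth'),
        strict=False)
    results[src] = evaluate_on_validation(model, source=src)  # placeholder
print(results)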
Thanks @rdroste! Let me give this a try.
Hi! Thanks for releasing the code for UNISAL!
Is it possible to provide a minimal set of scripts and instructions to fine-tune the UNISAL model on a dataset other than the ones listed in the repository? I notice that data.py, train.py, etc. come with various commands to switch across the datasets you used to train UNISAL. It would be great to get some help with training on just one new dataset.
Thanks, Ekta