nateanl opened this issue 3 years ago
cc @parmeet So, I gave some thoughts about this, and adding the freezing feature makes sense. The question I have is where the best place to put the feature is.
There are three possible places I can think of.
1. torchaudio.models.freeze_params_(model)
   Pros: simple, most modular, and most reusable.
   Cons: if some model needs special treatment, it is hard to implement.
2. Model.freeze() (similar to the SpeechBrain example above)
   Pros: specialization is easy.
   Cons: each model has to implement the method. (Putting the method in a base class or Mixin would make this easy while specialization is still possible.)
3. get_model(freeze=True) (torchtext #1428)
   Pros: the freezing pattern is closer to the user's use case, so if a bundle has to be picky about how the model should be frozen for the expected use case, this is the right place.
   Cons: if this is the only place that implements the freezing functionality, it is not reusable.
Now I think it is possible to provide a combination of the above while reusing the core implementation as a utility function:
freeze(model)  <-  model.freeze(self)  <-  bundle.get_model(freeze=True)
where the bundle-level option is added if (and only if) bundle-wise specialization is required. Use cases are like the following (a sketch of how the layers could share one implementation follows the examples).
# bundle has freeze specialization
model = bundle.get_model(freeze=True)
or
# bundle does not have freeze specialization
model = bundle.get_model()
model.freeze()
model = torchaudio.models.SomeModel()
model.freeze()
model = MyModel()
torchaudio.models.freeze_params_(model)
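For concreteness, here is a minimal sketch of how the three layers could share one core implementation. The names follow the proposal above, and SomeModel/SomeBundle are hypothetical stand-ins, not existing torchaudio APIs.
from torch import nn

def freeze_params_(model: nn.Module) -> nn.Module:
    # Core utility: disable gradients for every parameter, in-place.
    for p in model.parameters():
        p.requires_grad = False
    return model

class SomeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)

    def forward(self, x):
        return self.encoder(x)

    def freeze(self):
        # Model method delegates to the shared utility;
        # a subclass can override this to freeze only part of itself.
        return freeze_params_(self)

class SomeBundle:
    def get_model(self, freeze: bool = False) -> nn.Module:
        model = SomeModel()  # a real bundle would also load pre-trained weights
        if freeze:
            model.freeze()  # bundle-wise specialization could happen here instead
        return model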
What do you think?
model = bundle.get_model(freeze=True) makes sense to me. The freeze operation only happens in pre-trained cases, while model.freeze() can be used when the model is not trained at all, which may potentially lead to improper usage.
In principle, the conditions under which a freeze option is meaningful are: 1) part of the model is already pre-trained, and 2) the pre-trained part is used in conjunction with additional layers that require some training.
Making freeze part of the builder APIs (get_model, explicit factory functions, etc.) would allow: 1) library maintainers to log proper warning (or error) messages when the usage is not appropriate (for instance when the user sets freeze to True while instantiating a model without pre-trained weights, e.g. here), and 2) customizing behavior (specialization). For instance, when builder APIs provide composite models, it is often not obvious how and which part to freeze after the model is instantiated. Whereas, having the freeze option during construction would allow the library to define an out-of-the-box implementation for at least the default behavior (e.g. freezing the encoder and leaving the task-specific head untouched).
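As a hedged illustration of point 1) above, a builder could warn when freeze is requested without pre-trained weights. CompositeModel and this get_model signature are made up for the example, not an existing torchaudio or torchtext API.
import warnings
from torch import nn

class CompositeModel(nn.Module):
    # Hypothetical encoder + task-head model, used only for illustration.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)
        self.head = nn.Linear(16, 4)

    def forward(self, x):
        return self.head(self.encoder(x))

def get_model(freeze: bool = False, pretrained_weights=None) -> nn.Module:
    model = CompositeModel()
    if pretrained_weights is not None:
        model.load_state_dict(pretrained_weights)
    elif freeze:
        # Freezing randomly initialized weights is almost never what the user wants.
        warnings.warn("freeze=True was requested but no pre-trained weights were given")
    if freeze:
        # Default specialization: freeze the encoder, leave the task head trainable.
        for p in model.encoder.parameters():
            p.requires_grad = False
    return model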
Now there are more complex cases, for instance freezing only part of the pre-trained model weights (up to certain layers), which is certainly more complex behavior and could benefit from a specialized model API method (option 2).
Making freeze a general model API requires more thoughtfulness IMO. For example, when the user calls model.freeze() on a composite model (encoder + task), it is not immediately clear whether we should freeze only the encoder or both encoder and task head (which of course is not meaningful). I also agree with @nateanl's comments above, which are along these lines.
So in summary, regarding the various options stated by @mthrok: Option 1 is certainly safe as a utility; we could just freeze all model parameters without needing to worry about the underlying semantics or usage. The user is completely responsible for the consequences.
Option 2 requires some thoughtfulness to define and implement the correct behavior, as it becomes part of the model API itself (compositionality could bring its own nuances). But it could come in handy for dealing with more complex freezing scenarios (like freezing only part of the pre-trained model).
Option 3 is safe at the moment. It provides a clear message to the user about what to expect. And, as stated already, it is closer to the user's use case, leaving less room for improper usage :).
In terms of the implementation details, I'm thinking about the behaviour of freeze. Besides setting requires_grad=False, we need to be careful about model.eval(), as it changes the behaviour of Dropout and BatchNorm.
If we use the model as a feature extractor, the gradient of the model won't be passed to other models, so it's safe to set the mode to eval.
If we jointly train the model with some front-end models, we can still turn off Dropout to make full use of the model, but the BatchNorm layers should be set to train mode, as this affects the training performance.
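Just to make the distinction concrete (a sketch in plain PyTorch, not a proposed torchaudio API): requires_grad controls whether the weights can be updated, while train()/eval() independently controls Dropout and BatchNorm behaviour.
from torch import nn

encoder = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(0.1))

# Freeze the weights only; because the module stays in train mode, Dropout
# still drops activations and BatchNorm still updates its running statistics.
for p in encoder.parameters():
    p.requires_grad = False
encoder.train()

# eval() makes the behaviour deterministic but does not touch requires_grad;
# the two switches are independent.
encoder.eval()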
@parmeet @nateanl Thanks for the input. Let's proceed with option 3.
I guess both Dropout and BatchNorm could be useful when training the task head while the encoder is frozen. So in the case where only the task head is fine-tuned, it probably won't hurt to leave the encoder in train mode. @abhinavarora, @hudeven what are your thoughts on this?
I agree with @parmeet. Dropout can always be used during training, regardless of whether we freeze the encoder or not.
Thanks @abhinavarora and @parmeet. In the second case, if the pre-trained model comes after the model we want to train, then using Dropout should be fine; we can just freeze the weights.
(I guess it's model-specific.) I saw that the Wav2Vec2 model in torchaudio behaves differently between training and eval, see https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py#L396
for layer in self.layers:
    if not (self.training and torch.rand(1).item() <= self.layer_drop):
        x = layer(x, attention_mask)
If the model is used as a feature extractor, users can dump the features to a pickle file for future use. We should set the model to eval mode so that all layers are used to compute the feature representations.
Shall we consider both cases in the freeze option?
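For the feature-extractor case, this is roughly what dumping deterministic features could look like with torchaudio's Wav2Vec2 pipeline; the input file name is a placeholder and the audio is assumed to already be at the bundle's sample rate.
import pickle
import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_BASE
model = bundle.get_model()
model.eval()  # disable Dropout and layer_drop so every layer contributes

waveform, sample_rate = torchaudio.load("speech.wav")  # mono file; (1, num_frames) acts as a batch of one
with torch.inference_mode():
    features, _ = model.extract_features(waveform)  # list of per-layer tensors

with open("features.pkl", "wb") as f:
    pickle.dump([feat.cpu() for feat in features], f)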
It's likely I am missing the context here, so just some clarifying questions:
1) Typically, my understanding so far is that we freeze the lower layers (closer to the input) and train the higher-level layers. When the pre-trained model comes after the model we are training, does this mean the higher-level layers are fixed (and pre-trained), and the loss is calculated using the output of the pre-trained model?
2) When the model is used as a feature extractor, is it still in the context of some training, or is it a standalone usage of the model in inference mode?
does this mean the higher-level layers are fixed (and pre-trained), and the loss is calculated using the output of the pre-trained model?
Exactly, we use the gradient of the pre-trained model to back-propagate to the lower-level (untrained) layers.
If we use the freeze design option proposed by @mthrok, we can define a specific freeze option for each model, and then we can solve these cases one by one.
@parmeet To supplement the above point, @nateanl is working on a speech enhancement (noise reduction) system that is attached in front of a pre-trained speech recognition model.
@nateanl If you are using a pre-trained speech recognition model downstream of the speech enhancement module, I feel the speech recognition model serves as a loss function, and it should not have random behavior (of course, the best way to find out is to run experiments).
Thanks @nateanl and @mthrok for providing additional context and clarification.
I wouldn't worry much about where to place the freeze options; depending on the use case, one option might suit better than another and vice versa. What I would be careful about, though, is to ensure that the semantic meaning of freeze is preserved and decoupled from model behavior.
Technically speaking, freezing implies that the weights of the model will not change during training. If we decide that the model behavior (like putting the model in eval mode, disabling layer_drop, etc.) is also implicitly controlled by the freeze parameter, then we would not be able to satisfy the use case where the user wants to freeze the weights but keep the model behavior stochastic (for whatever good reason). On the other hand, if this is indeed the right expected behavior in every possible usage of the Wav2Vec2 model (i.e. freeze also makes the whole model behave in a deterministic manner), then I think it's OK to do so, and we should clarify this in the documentation.
@parmeet Thanks, this is very on-spot. I agree that we can separate the concerns of the model parameters and model behaviors. And I think it's a better design decision for the library to support both cases, especially since there is no universally agreed convention on model behavior when freezing.
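In code, the decoupling could look like the following; since the freeze option does not exist yet, the snippet uses the manual requires_grad loop, and the behaviour switch is left entirely to the caller.
from torchaudio.pipelines import WAV2VEC2_ASR_BASE_960H

model = WAV2VEC2_ASR_BASE_960H.get_model()

# Freezing: the weights will no longer be updated by any optimizer.
for p in model.parameters():
    p.requires_grad = False

# Behaviour is controlled independently of freezing:
model.eval()   # deterministic: Dropout and layer_drop disabled
model.train()  # or keep the stochastic behaviour even though the weights are frozen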
In the fastai library, the usual practice is to set all the parameters to requires_grad = False (like @nateanl mentioned, without model.eval()). Usually, some arguments are passed (like a layer number) to freeze/unfreeze certain layers of the model.
for para in net.parameters():
    para.requires_grad = False
https://github.com/fastai/fastai/blob/master/fastai/optimizer.py#L30
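A small sketch of the layer-number variant in plain PyTorch (the freeze_to name is borrowed from fastai's convention; the network here is just a stand-in):
from torch import nn

def freeze_to(model: nn.Module, n: int) -> None:
    # Freeze the first n child modules; leave the rest trainable.
    for i, child in enumerate(model.children()):
        trainable = i >= n
        for p in child.parameters():
            p.requires_grad = trainable

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
freeze_to(net, 2)  # first two children frozen, final layer still trainable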
🚀 The feature
In some research cases, Wav2Vec2 or HuBERT is expected to be frozen (i.e. requires_grad=False for all params). It would be good to add an argument that sets the model to a frozen state so that users don't need to set requires_grad=False by themselves.
Motivation, pitch
SpeechBrain has a similar implementation, adding a freeze argument when initializing the model: https://github.com/speechbrain/speechbrain/blob/f1f421b3bb58dabc75d67c3fd5f6e3359943b927/speechbrain/lobes/models/fairseq_wav2vec.py#L78
cc @mthrok
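For reference, a minimal sketch of what the requested constructor-level freeze argument could look like, loosely modeled on the SpeechBrain pattern linked above; the class name and layers are illustrative stand-ins, not torchaudio's actual implementation.
from torch import nn

class PretrainedEncoder(nn.Module):
    def __init__(self, freeze: bool = False):
        super().__init__()
        self.encoder = nn.Linear(16, 16)  # stand-in for Wav2Vec2/HuBERT layers
        if freeze:
            for p in self.parameters():
                p.requires_grad = False

model = PretrainedEncoder(freeze=True)  # no manual requires_grad loop needed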