Feature description
Expected behavior: There should be a boolean flag (amp=True/False) in TransformerModelArguments and SetFitModelArguments. The flag should also be added for KimCNNClassifier. Ideally, its model arguments would first be adjusted to match the former two classes; otherwise, the flag could go in the constructor as a temporary solution.
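To illustrate how such a flag might be consumed, here is a minimal sketch of a single PyTorch training step gated by `amp`. `HypotheticalModelArguments`, `make_scaler`, and `train_step` are made-up names for illustration only and are not part of the existing classes. The sketch assumes float16 autocast is wanted only on CUDA and falls back to full precision elsewhere. The finite-loss check is just one possible guard.

```python
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class HypotheticalModelArguments:
    """Illustrative stand-in; not the library's actual arguments class."""
    amp: bool = False


def make_scaler(args: HypotheticalModelArguments):
    # Gradient scaling is only active when AMP is requested and CUDA is available.
    # Newer PyTorch versions also offer torch.amp.GradScaler as a replacement.
    return torch.cuda.amp.GradScaler(enabled=args.amp and torch.cuda.is_available())


def train_step(model, optimizer, scaler, inputs, labels, args):
    device_type = inputs.device.type
    # Assumption: mixed precision is only enabled for CUDA inputs.
    use_amp = args.amp and device_type == "cuda"
    dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

    optimizer.zero_grad()
    with torch.autocast(device_type=device_type, dtype=dtype, enabled=use_amp):
        loss = nn.functional.cross_entropy(model(inputs), labels)

    # Fail loudly instead of silently training on a NaN/inf loss.
    if not torch.isfinite(loss):
        raise FloatingPointError("non-finite loss encountered")

    scaler.scale(loss).backward()  # no-op scaling when the scaler is disabled
    scaler.step(optimizer)         # calls optimizer.step() (skipped on inf/NaN grads)
    scaler.update()
    return loss.item()
```

With `amp=False` (the default here), the step behaves like an ordinary full-precision update, so existing behavior would stay unchanged unless the flag is set.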
Motivation
This could provide a nice speedup for use cases where mixed precision is good enough.
Additional comments
Note: I tried this once in a hacky way and stopped because I got NaN values in the loss (and I did not really need the feature at the time). This was likely a PyTorch bug, which should be kept in mind here: https://github.com/pytorch/pytorch/releases/tag/v2.0.1