stephenyan1231 closed this pull request 6 months ago.
This pull request was exported from Phabricator. Differential Revision: D52152706
This pull request has been merged in facebookresearch/d2go@c2256758202eb51ae8f21200f58dcbb70ca96690.
Summary: When we build a QAT model with the FX graph mode APIs prepare_qat_fx and convert_fx, these APIs run symbolic tracing through module.forward().
In certain cases, such as when a module takes a constant tensor input, symbolic tracing adds new tensor attributes with the name prefix _tensor_constant (https://fburl.com/code/msc4ch4o), which become new keys in the QAT model's state dict.
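As a minimal illustration, assuming a recent PyTorch release, plain torch.fx.symbolic_trace (standing in here for the FX tracing that prepare_qat_fx performs) shows the mechanism: a tensor created eagerly inside forward() is stored on the traced GraphModule as _tensor_constant0, registered as a buffer, and therefore shows up as an extra state dict key. The AddOffset module is a hypothetical example, not code from this PR.

```python
import torch
import torch.fx


class AddOffset(torch.nn.Module):
    """Hypothetical module whose forward() builds a constant tensor."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # torch.tensor(...) runs eagerly during tracing, so the tracer sees a
        # concrete tensor (not a Proxy) when it is used in the add below and
        # stores it on the traced module under a "_tensor_constant" name.
        offset = torch.tensor([1.0, 2.0, 3.0, 4.0])
        return self.linear(x) + offset


model = AddOffset()
traced = torch.fx.symbolic_trace(model)

orig_keys = set(model.state_dict().keys())
traced_keys = set(traced.state_dict().keys())
new_keys = sorted(traced_keys - orig_keys)
print(new_keys)  # ['_tensor_constant0']
```

The original module's state dict is unchanged; only the traced module gains the extra key, which is exactly the mismatch a key-by-key comparison between the two models runs into.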
The current implementation of _setup_non_qat_to_qat_state_dict_map asserts that the state dicts of the original model and the QAT model have the same number of keys, so these extra _tensor_constant keys make that assertion fail.
Thus, we extend the qat_state_dict_keys_to_ignore method with a new argument that lets callers specify state dict keys in the QAT model to ignore.
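The idea can be sketched, under stated assumptions, as a key-mapping helper that drops caller-specified QAT keys before the count check. The function name build_non_qat_to_qat_key_map, its qat_keys_to_ignore parameter, and the positional pairing of keys are all hypothetical illustrations; this is not d2go's _setup_non_qat_to_qat_state_dict_map or qat_state_dict_keys_to_ignore.

```python
from typing import Dict, Iterable, List


def build_non_qat_to_qat_key_map(
    non_qat_keys: List[str],
    qat_keys: List[str],
    qat_keys_to_ignore: Iterable[str] = (),
) -> Dict[str, str]:
    """Hypothetical helper for illustration; not d2go's actual implementation."""
    ignore = set(qat_keys_to_ignore)
    # Drop QAT-only keys the caller asked to ignore, e.g. "_tensor_constant0"
    # added by FX symbolic tracing.
    kept_qat_keys = [k for k in qat_keys if k not in ignore]
    # The invariant described in the PR: both models must expose the same
    # number of state dict keys once the ignored keys are removed.
    assert len(non_qat_keys) == len(kept_qat_keys), (
        f"{len(non_qat_keys)} non-QAT keys vs {len(kept_qat_keys)} QAT keys"
    )
    # Assumption for this sketch: keys correspond positionally.
    return dict(zip(non_qat_keys, kept_qat_keys))


non_qat = ["linear.weight", "linear.bias"]
qat = ["linear.weight", "linear.bias", "_tensor_constant0"]

mapping = build_non_qat_to_qat_key_map(
    non_qat, qat, qat_keys_to_ignore=["_tensor_constant0"]
)
print(mapping)  # {'linear.weight': 'linear.weight', 'linear.bias': 'linear.bias'}

# Without the ignore list, the extra traced key trips the count assertion.
try:
    build_non_qat_to_qat_key_map(non_qat, qat)
except AssertionError as err:
    print("AssertionError:", err)  # AssertionError: 2 non-QAT keys vs 3 QAT keys
```

Keeping the assertion and filtering only explicitly named keys means an unexpected extra key still fails loudly, rather than being silently skipped.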
Differential Revision: D52152706