google-research / t5x

Apache License 2.0

target dict keys and state dict keys do not match while trying to restore checkpoint #1196

Open pjlintw opened 1 year ago

pjlintw commented 1 year ago

I'm trying to fine-tune T5-small v1.1 on a single GPU using the sample script (singlenode_ft_frompile.sh).

I don't fully understand how the script loads the checkpoint, but I've set GOOGLE_CLOUD_BUCKET_NAME, and the script does seem to find the checkpoint (gs://t5-data/pretrained_models/t5x/t5_1_1_small) on Google Cloud Storage.

It appears that the target dictionary cannot be restored from the state dictionary, possibly because the target dictionary contains an extra key, '1'.

I get the following error:

  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/gin/utils.py", line 42, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/nethome/pjlin/pythonProjects/prompt-tuning/thirt_party/t5x//t5x/train.py", line 409, in train
    train_state = checkpoint_manager.restore(
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/t5x/utils.py", line 562, in restore
    self._checkpointer.restore(
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/t5x/utils.py", line 328, in restore
    return self._restore_checkpointer.restore(
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/t5x/checkpoints.py", line 1119, in restore
    written_state_dict = serialization.from_state_dict(dummy_written_state_dict,
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 93, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 170, in _restore_dict
    return {key: from_state_dict(value, states[str(key)], name=str(key))
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 170, in <dictcomp>
    return {key: from_state_dict(value, states[str(key)], name=str(key))
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 93, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 170, in _restore_dict
    return {key: from_state_dict(value, states[str(key)], name=str(key))
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 170, in <dictcomp>
    return {key: from_state_dict(value, states[str(key)], name=str(key))
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 93, in from_state_dict
    return ty_from_state_dict(target, state)
  File "/nethome/pjlin/anaconda3/envs/prompt-tuning/lib/python3.8/site-packages/flax/serialization.py", line 166, in _restore_dict
    raise ValueError('The target dict keys and state dict keys do not match,'
ValueError: The target dict keys and state dict keys do not match, target dict contains keys {'1'} which are not present in state dict at path ./state/param_states
  In call to configurable 'train' (<function train at 0x7f737de4eca0>)

And here is the command I ran:

. t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh \
    mnli2 \
    small \
    float32 \
    1 \
    8 \
    /nethome/pjlin/pythonProjects/prompt-tuning/thirt_party/t5x/t5x/contrib/gpu/scripts_gpu/outputs \
    gs://t5-data/pretrained_models/t5x/t5_1_1_small \
    0 \
    false

I also get the same error when I try the base-size model.
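The final ValueError in the traceback above comes from flax.serialization while it restores nested dicts: roughly, at each level every key of the target must also appear in the state dict, and the message reports the extra target keys together with the current path. The following is a simplified, standard-library-only imitation of that check, not flax's actual code; restore_like is a hypothetical name. It reproduces the same message for a toy target with an extra '1' key under state/param_states.

```python
from typing import Any


def restore_like(target: Any, state: Any, path: str = ".") -> Any:
    """Simplified imitation of the key check flax applies when restoring dicts.

    Every key of `target` must also be a key of `state`; extra keys in `state`
    are ignored by this sketch. Leaves are taken from `state`.
    """
    if not isinstance(target, dict):
        return state  # leaf: use the stored value
    missing = set(map(str, target.keys())).difference(state.keys())
    if missing:
        raise ValueError(
            "The target dict keys and state dict keys do not match, target dict "
            f"contains keys {missing} which are not present in state dict at path {path}"
        )
    # Recurse into each child, extending the path the way the error reports it.
    return {k: restore_like(v, state[str(k)], f"{path}/{k}") for k, v in target.items()}


# Toy trees reproducing the path and extra key from the error (contents are made up).
target = {"state": {"param_states": {"1": {"0": {"count": None}}}}}
state = {"state": {"param_states": {"decoder": {}, "encoder": {}}}}

try:
    restore_like(target, state)
except ValueError as err:
    print(err)
# prints: The target dict keys and state dict keys do not match, target dict
# contains keys {'1'} which are not present in state dict at path ./state/param_states
```

In other words, the restore fails as soon as the tree the training run expects and the tree read from the checkpoint disagree on key names at some level, here directly under param_states.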

kiriharulxh commented 1 year ago

I've run into the same problem. Have you solved it? @pjlintw

Jeevesh8 commented 1 year ago

Same issue here. Has anyone resolved it? It seems that the variables ckpt_state_dict and dummy_written_state_dict in t5x/checkpoints.py (line 1119) have different structures:

ckpt_state_dict:

state
    param_states
        decoder
            decoder_norm
                scale
                    m
                        (1,)
                    v
                        (768,)
                    v_col
                        (1,)
                    v_row
                        (1,)
            layers_0
                encoder_decoder_attention
                    key
                        kernel
                            m
                                (1,)
                            v
                                (1,)
                            v_col
                                (768,)
                            v_row
                                (768,)
                    out
                        kernel
                            m
                                (1,)
.
.
.
.
        encoder
            encoder_norm
                scale
                    m
                        (1,)
                    v
                        (768,)
                    v_col
                        (1,)
                    v_row
                        (1,)
            layers_0
                attention
                    key
                        kernel
                            m
                                (1,)
                            v
                                (1,)
                            v_col
                                (768,)
                            v_row
                                (768,)
                    out
                        kernel
                            m
                                (1,)
                            v
                                (1,)
                            v_col
                                (768,)
                            v_row
                                (768,)
target
    decoder
        decoder_norm
            scale
                (768,)
        layers_0
            encoder_decoder_attention
                key
                    kernel
                        driver

And for dummy_written_state_dict:

state
    param_states
        1
            0
                count
                mu
                    decoder
                        decoder_norm
                            scale
                        layers_0
                            encoder_decoder_attention
                                key
                                    kernel
                                out
                                    kernel
.
.
.
.
                nu
                    decoder
                        decoder_norm
                            scale
                        layers_0
                            encoder_decoder_attention
                                key
                                    kernel
                                out
                                    kernel
                                query
                                    kernel
                                value
                                    kernel
                            mlp
                                wi
                                    kernel
                                wo
                                    kernel
                            pre_cross_attention_layer_norm
                                scale
                            pre_mlp_layer_norm
                                scale
                            pre_self_attention_layer_norm
                                scale
                            self_attention
                                key
                                    kernel
target
    decoder
        decoder_norm
            scale
        layers_0
            encoder_decoder_attention
                key
                    kernel
                out
                    kernel
                query
                    kernel
.
.
.
.

@pjlintw @kiriharulxh
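One possible reading of these dumps, offered as an interpretation rather than a confirmed diagnosis: ckpt_state_dict stores per-parameter optimizer state named m, v, v_col and v_row (the field names of t5x's Adafactor state), while dummy_written_state_dict nests count, mu and nu (the field names of optax's Adam state, ScaleByAdamState) under the keys '1' and '0'. Numeric string keys are what flax produces when serializing tuples (elements are keyed by their index as a string, namedtuples by field name), so '1' and '0' would be positions inside a chained optimizer state. The sketch below imitates that rule with the standard library only. AdamLikeState, EmptyLikeState, opt_state and to_state_dict_like are hypothetical stand-ins, and the assumption that the fine-tuning config uses a chained optax optimizer is not confirmed anywhere in this thread.

```python
from collections import namedtuple
from typing import Any

# Hypothetical stand-ins: AdamLikeState copies the field names of optax's
# ScaleByAdamState (count, mu, nu); EmptyLikeState mimics a field-less state.
AdamLikeState = namedtuple("AdamLikeState", ["count", "mu", "nu"])
EmptyLikeState = namedtuple("EmptyLikeState", [])


def to_state_dict_like(x: Any) -> Any:
    """Simplified imitation of flax.serialization.to_state_dict for containers."""
    if isinstance(x, tuple) and hasattr(x, "_fields"):
        # namedtuple: keyed by field name
        return {f: to_state_dict_like(getattr(x, f)) for f in x._fields}
    if isinstance(x, (tuple, list)):
        # plain tuple/list: keyed by element index, converted to a string
        return {str(i): to_state_dict_like(v) for i, v in enumerate(x)}
    if isinstance(x, dict):
        return {str(k): to_state_dict_like(v) for k, v in x.items()}
    return x  # leaf (an array in real code, a float here)


# Toy parameter tree with a single leaf, standing in for the T5 params.
params = {"decoder": {"decoder_norm": {"scale": 0.0}}}

# Hypothetical chained optimizer state: index 1 is an inner chain whose
# index 0 holds the Adam-style moments, giving the '1' -> '0' nesting.
opt_state = (EmptyLikeState(), (AdamLikeState(0, params, params), EmptyLikeState()))

param_states = to_state_dict_like(opt_state)
print(list(param_states))            # ['0', '1']
print(list(param_states["1"]["0"]))  # ['count', 'mu', 'nu']
print(param_states["0"])             # {}
```

The dump above lists only '1'; one plausible reason, not verified here, is that an empty state at index 0 serializes to an empty dict and leaves nothing to print. Under these assumptions, the '1' in the error is simply a tuple index of the new optimizer's state, which has no counterpart in the parameter-keyed state stored in the pretrained checkpoint.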