google-research / language

Shared repository for open-sourced projects from the Google AI Language team.
https://ai.google/research/teams/language/
Apache License 2.0

[NQG] AssertionError: Found 392 Python objects that were not bound to checkpointed values #151

Open KaiserWhoLearns opened 1 year ago

KaiserWhoLearns commented 1 year ago

Hi,

Thanks for sharing the code. I am trying to replicate the NQG results on GeoQuery (https://github.com/google-research/language/tree/master/language/compgen/nqg). After finishing the QCFG induction and training-data generation steps, I downloaded an uncased BERT checkpoint from https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md, because the original link in the repository is broken.
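Before training, it may help to check which format the downloaded checkpoint uses. A TF1-style checkpoint (such as the original Google BERT release) stores variables under global names like `bert/encoder/...`, while a TF2 checkpoint written with `tf.train.Checkpoint` contains a `_CHECKPOINTABLE_OBJECT_GRAPH` entry and keys rooted at the attribute names the objects were saved under. The sketch below summarizes the `(name, shape)` pairs returned by `tf.train.list_variables`; the helper name `summarize_checkpoint_keys` is hypothetical, and the sample keys are illustrative, not copied from the actual Model Garden checkpoint.

```python
from collections import Counter
from typing import Dict, Iterable, Sequence, Tuple

# Entry that TF2 object-based checkpoints (written by tf.train.Checkpoint) contain.
OBJECT_GRAPH_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"


def summarize_checkpoint_keys(
    variables: Iterable[Tuple[str, Sequence[int]]]
) -> Dict[str, object]:
    """Summarize (name, shape) pairs such as those from tf.train.list_variables."""
    names = [name for name, _ in variables]
    # The first path component is the global scope in a TF1 checkpoint
    # (e.g. "bert") or the attribute an object was saved under in a TF2 one.
    top_level = Counter(
        name.split("/", 1)[0] for name in names if name != OBJECT_GRAPH_KEY
    )
    return {"object_based": OBJECT_GRAPH_KEY in names, "top_level": dict(top_level)}


if __name__ == "__main__":
    # Illustrative (hypothetical) keys only.
    tf2_style = [
        (OBJECT_GRAPH_KEY, []),
        ("encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE", [30522, 1024]),
        ("save_counter/.ATTRIBUTES/VARIABLE_VALUE", []),
    ]
    tf1_style = [
        ("bert/embeddings/word_embeddings", [30522, 768]),
        ("bert/encoder/layer_0/attention/self/query/kernel", [768, 768]),
    ]
    print(summarize_checkpoint_keys(tf2_style))
    # {'object_based': True, 'top_level': {'encoder': 1, 'save_counter': 1}}
    print(summarize_checkpoint_keys(tf1_style))
    # {'object_based': False, 'top_level': {'bert': 2}}
```

In practice one would call `summarize_checkpoint_keys(tf.train.list_variables(path))` with `path` set to the checkpoint directory or prefix under `${BERT_DIR}`, then compare the top-level names with whatever object layout the training script restores into.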

I used the command below for training.

python model/parser/training/train_model.py \
  --input=${TF_EXAMPLES} \
  --config=${CONFIG} \
  --model_dir=${MODEL_DIR} \
  --bert_dir=${BERT_DIR} 

However, I get the error below, and NQG training cannot proceed. Could you please point me to the exact model checkpoint you used for NQG, and explain what could be causing this error?
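The assertion in the traceback below is raised by TensorFlow's object-based checkpoint restore: `tf.train.Checkpoint.restore` matches Python objects to checkpoint entries by the path of attribute names from the root, and `assert_existing_objects_matched()` raises an `AssertionError` if any variable that already exists in Python was not matched. The sketch below reproduces that behaviour with a toy module, assuming TensorFlow 2.x; `TinyEncoder`, `save_encoder` and `try_restore` are hypothetical names, not NQG code. Saving under the attribute `encoder` and restoring under `model` produces the same kind of error. This illustrates one plausible cause (a checkpoint saved with a different object layout than the one the training script expects), not a confirmed diagnosis for this issue.

```python
import os
import tempfile

import tensorflow as tf


class TinyEncoder(tf.Module):
    """Hypothetical stand-in for a BERT encoder: one kernel and one bias."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.random.normal([4, 4]), name="kernel")
        self.bias = tf.Variable(tf.zeros([4]), name="bias")


def save_encoder(attr_name: str, directory: str) -> str:
    """Save a fresh TinyEncoder tracked as `attr_name`; return the checkpoint path."""
    ckpt = tf.train.Checkpoint(**{attr_name: TinyEncoder()})
    return ckpt.save(os.path.join(directory, "ckpt"))


def try_restore(attr_name: str, path: str) -> bool:
    """Restore into a new TinyEncoder tracked as `attr_name`.

    Returns True if every existing variable was bound to a checkpoint value.
    """
    status = tf.train.Checkpoint(**{attr_name: TinyEncoder()}).restore(path)
    try:
        status.assert_existing_objects_matched()
        return True
    except AssertionError as err:
        # Same family of message as in the traceback: unmatched Python objects.
        print(err)
        return False


if __name__ == "__main__":
    path = save_encoder("encoder", tempfile.mkdtemp())
    print(try_restore("encoder", path))  # same attribute name: matched
    print(try_restore("model", path))    # different attribute name: unmatched
```

The variables are created before `restore` is called, so a path mismatch surfaces immediately as unmatched objects rather than as a deferred restore.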

Traceback (most recent call last):
  File "/private/home/kaisersun/CompGenComparision/baseline_replication/TMCD/model/parser/training/train_model.py", line 139, in <module>
    app.run(main)
  File "/private/home/kaisersun/.conda/envs/gen/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/private/home/kaisersun/.conda/envs/gen/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/private/home/kaisersun/CompGenComparision/baseline_replication/TMCD/model/parser/training/train_model.py", line 135, in main
    train_model(strategy)
  File "/private/home/kaisersun/CompGenComparision/baseline_replication/TMCD/model/parser/training/train_model.py", line 99, in train_model
    status.assert_existing_objects_matched()
  File "/private/home/kaisersun/.conda/envs/gen/lib/python3.9/site-packages/tensorflow/python/training/tracking/util.py", line 846, in assert_existing_objects_matched
    raise AssertionError(
AssertionError: Found 392 Python objects that were not bound to checkpointed values, likely due to changes in the Python program. Showing 10 of 392 unmatched objects: [
  <tf.Variable 'transformer/layer_7/self_attention/attention_output/bias:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
  <tf.Variable 'transformer/layer_2/output/bias:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
  <tf.Variable 'transformer/layer_1/output_layer_norm/beta:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
  <tf.Variable 'transformer/layer_9/self_attention/attention_output/bias:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
  <tf.Variable 'transformer/layer_15/self_attention/query/kernel:0' shape=(1024, 16, 64) dtype=float32, numpy=array([[[-5.33815415e-04, 9.28729866e-03, -7.46294972e-04, ...], ...], ...], dtype=float32)>,
  <tf.Variable 'transformer/layer_13/self_attention/query/bias:0' shape=(16, 64) dtype=float32, numpy=array([[0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>,
  <tf.Variable 'transformer/layer_7/output_layer_norm/beta:0' shape=(1024,) dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>,
  <tf.Variable 'transformer/layer_6/self_attention/attention_output/kernel:0' shape=(16, 64, 1024) dtype=float32, numpy=array([[[-0.01926116, -0.02581214, -0.01691924, ...], ...], ...], dtype=float32)>,
  <tf.Variable 'transformer/layer_12/self_attention/query/kernel:0' shape=(1024, 16, 64) dtype=float32, numpy=array([[[0.00820519, 0.00556104, -0.03520906, ...], ...], ...], dtype=float32)>,
  <tf.Variable 'transformer/layer_1/output/kernel:0' shape=(4096, 1024) dtype=float32, numpy=array([[9.4641419e-03, -3.1907026e-02, -2.4522197e-02, ...], ...], dtype=float32)>
]
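The shapes of the unmatched variables in the traceback also describe the encoder being restored into: `self_attention/query/kernel` has shape (1024, 16, 64), i.e. hidden size 1024 split into 16 heads of size 64; `layer_1/output/kernel` has shape (4096, 1024), i.e. intermediate size 4096; and layer indices go up to at least 15. These values match BERT-Large (24 layers, hidden 1024, 16 heads, intermediate 4096) rather than BERT-Base (12 layers, hidden 768, 12 heads, intermediate 3072), so a BERT-Base checkpoint could not satisfy this model even if the names lined up. The hypothetical helper below extracts these dimensions from `(name, shape)` pairs, assuming the `transformer/layer_N/...` naming visible in the traceback; it does not read a checkpoint itself.

```python
import re
from typing import Dict, Iterable, Sequence, Tuple

# Matches the layer index in names such as "transformer/layer_15/...".
_LAYER_RE = re.compile(r"transformer/layer_(\d+)/")


def infer_encoder_dims(variables: Iterable[Tuple[str, Sequence[int]]]) -> Dict[str, int]:
    """Infer Transformer dimensions from (variable name, shape) pairs."""
    dims: Dict[str, int] = {}
    max_layer = -1
    for name, shape in variables:
        match = _LAYER_RE.search(name)
        if match:
            max_layer = max(max_layer, int(match.group(1)))
        if name.endswith("self_attention/query/kernel:0"):
            # Shape is (hidden_size, num_heads, head_size).
            dims["hidden_size"], dims["num_heads"], dims["head_size"] = shape
        elif name.endswith("/output/kernel:0"):
            # Feed-forward output projection: (intermediate_size, hidden_size).
            # "attention_output/kernel:0" does not match because of the "_".
            dims["intermediate_size"] = shape[0]
    # Only a sample of the variables is listed, so this is a lower bound.
    dims["min_num_layers"] = max_layer + 1
    return dims


# The 10 unmatched variables shown in the traceback (names and shapes only).
UNMATCHED_SAMPLE = [
    ("transformer/layer_7/self_attention/attention_output/bias:0", (1024,)),
    ("transformer/layer_2/output/bias:0", (1024,)),
    ("transformer/layer_1/output_layer_norm/beta:0", (1024,)),
    ("transformer/layer_9/self_attention/attention_output/bias:0", (1024,)),
    ("transformer/layer_15/self_attention/query/kernel:0", (1024, 16, 64)),
    ("transformer/layer_13/self_attention/query/bias:0", (16, 64)),
    ("transformer/layer_7/output_layer_norm/beta:0", (1024,)),
    ("transformer/layer_6/self_attention/attention_output/kernel:0", (16, 64, 1024)),
    ("transformer/layer_12/self_attention/query/kernel:0", (1024, 16, 64)),
    ("transformer/layer_1/output/kernel:0", (4096, 1024)),
]

if __name__ == "__main__":
    print(infer_encoder_dims(UNMATCHED_SAMPLE))
    # {'hidden_size': 1024, 'num_heads': 16, 'head_size': 64,
    #  'intermediate_size': 4096, 'min_num_layers': 16}
```

Comparing these numbers with the shapes listed for the downloaded checkpoint would show whether the sizes agree before looking at naming differences.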