deepmodeling / deepmd-kit

A deep learning package for many-body potential energy representation and molecular dynamics
https://docs.deepmodeling.com/projects/deepmd/
GNU Lesser General Public License v3.0
1.45k stars 499 forks source link

[BUG] dp compress KeyError: 'filter_type_2/bias_1_2' #658

Closed njzjz closed 3 years ago

njzjz commented 3 years ago

Summary

dp compress fails to compress my model and raises KeyError: 'filter_type_2/bias_1_2'. A possible cause is that my model uses the exclude_types parameter in its descriptor.

DeePMD-kit version, installation method, input file, running commands, error log, etc.

2.0.0.b0, conda

{
    "model": {
        "type_map": [
            "C",
            "H",
            "HW",
            "O",
            "OW",
            "P"
        ],
        "descriptor": {
            "type": "se_a",
            "sel": [
                6,
                11,
                400,
                6,
                200,
                1
            ],
            "rcut_smth": 1.0,
            "rcut": 6.0,
            "neuron": [
                25,
                50,
                100
            ],
            "resnet_dt": false,
            "axis_neuron": 12,
            "exclude_types": [
                [
                    2,
                    2
                ],
                [
                    2,
                    4
                ],
                [
                    4,
                    4
                ]
            ],
            "set_davg_zero": true,
            "seed": 1193825419
        },
        "fitting_net": {
            "neuron": [
                240,
                240,
                240
            ],
            "resnet_dt": true,
            "atom_ener": [
                null,
                null,
                0.0,
                null,
                0.0,
                null
            ],
            "seed": 3795656709
        }
    },
    "learning_rate": {
        "type": "exp",
        "start_lr": 0.001,
        "decay_steps": 400,
        "stop_lr": 5.00e-8
    },
    "loss": {
        "start_pref_e": 0.02,
        "limit_pref_e": 1,
        "start_pref_f": 1000,
        "limit_pref_f": 1,
        "start_pref_v": 0,
        "limit_pref_v": 0
    },
    "training": {
        "numb_steps": 4000,
        "disp_file": "lcurve.out",
        "disp_freq": 1000,
        "numb_test": 1,
        "save_freq": 1000,
        "disp_training": true,
        "time_training": true,
        "profiling": false,
        "profiling_file": "timeline.json",
        "training_data": {
        "systems": [
            "../data.init/MNDO298/C6H11HW260O6OW130P1",
            "../data.init/MNDO298/C6H11HW240O6OW120P1",
            "../data.init/MNDO298/C6H11HW212O6OW106P1",
            "../data.init/MNDO298/C6H11HW256O6OW128P1",
            "../data.init/MNDO298/C6H11HW230O6OW115P1",
            "../data.init/MNDO298/C6H11HW176O6OW88P1",
            "../data.init/MNDO298/C6H11HW262O6OW131P1",
            "../data.init/MNDO298/C6H11HW216O6OW108P1",
            "../data.init/MNDO298/C6H11HW238O6OW119P1",
            "../data.init/MNDO298/C6H11HW242O6OW121P1",
            "../data.init/MNDO298/C6H11HW200O6OW100P1",
            "../data.init/MNDO298/C6H11HW174O6OW87P1",
            "../data.init/MNDO298/C6H11HW244O6OW122P1",
            "../data.init/MNDO298/C6H11HW234O6OW117P1",
            "../data.init/MNDO298/C6H11HW224O6OW112P1",
            "../data.init/MNDO298/C6H11HW208O6OW104P1",
            "../data.init/MNDO298/C6H11HW210O6OW105P1",
            "../data.init/MNDO298/C6H11HW222O6OW111P1",
            "../data.init/MNDO298/C6H11HW226O6OW113P1",
            "../data.init/MNDO298/C6H11HW228O6OW114P1",
            "../data.init/MNDO298/C6H11HW190O6OW95P1",
            "../data.init/MNDO298/C6H11HW252O6OW126P1",
            "../data.init/MNDO298/C6H11HW266O6OW133P1",
            "../data.init/MNDO298/C6H11HW220O6OW110P1",
            "../data.init/MNDO298/C6H11HW182O6OW91P1",
            "../data.init/MNDO298/C6H11HW202O6OW101P1",
            "../data.init/MNDO298/C6H11HW206O6OW103P1",
            "../data.init/MNDO298/C6H11HW248O6OW124P1",
            "../data.init/MNDO298/C6H11HW178O6OW89P1",
            "../data.init/MNDO298/C6H11HW264O6OW132P1",
            "../data.init/MNDO298/C6H11HW236O6OW118P1",
            "../data.init/MNDO298/C6H11HW250O6OW125P1",
            "../data.init/MNDO298/C6H11HW232O6OW116P1",
            "../data.init/MNDO298/C6H11HW218O6OW109P1",
            "../data.init/MNDO298/C6H11HW204O6OW102P1",
            "../data.init/MNDO298/C6H11HW214O6OW107P1",
            "../data.init/MNDO298/C6H11HW246O6OW123P1",
            "../data.init/MNDO298/C6H11HW186O6OW93P1",
            "../data.init/MNDO298/C6H11HW188O6OW94P1",
            "../data.init/MNDO298/C6H11HW180O6OW90P1",
            "../data.init/MNDO298/C6H11HW194O6OW97P1",
            "../data.init/MNDO298/C6H11HW198O6OW99P1",
            "../data.init/MNDO298/C6H11HW254O6OW127P1",
            "../data.init/MNDO298/C6H11HW196O6OW98P1",
            "../data.init/MNDO298/C6H11HW258O6OW129P1",
            "../data.init/MNDO298/C6H11HW184O6OW92P1",
            "../data.init/MNDO298/C6H11HW192O6OW96P1",
            "../data.init/MNDO315/C6H11HW260O6OW130P1",
            "../data.init/MNDO315/C6H11HW240O6OW120P1",
            "../data.init/MNDO315/C6H11HW212O6OW106P1",
            "../data.init/MNDO315/C6H11HW256O6OW128P1",
            "../data.init/MNDO315/C6H11HW230O6OW115P1",
            "../data.init/MNDO315/C6H11HW176O6OW88P1",
            "../data.init/MNDO315/C6H11HW262O6OW131P1",
            "../data.init/MNDO315/C6H11HW216O6OW108P1",
            "../data.init/MNDO315/C6H11HW238O6OW119P1",
            "../data.init/MNDO315/C6H11HW242O6OW121P1",
            "../data.init/MNDO315/C6H11HW200O6OW100P1",
            "../data.init/MNDO315/C6H11HW174O6OW87P1",
            "../data.init/MNDO315/C6H11HW244O6OW122P1",
            "../data.init/MNDO315/C6H11HW234O6OW117P1",
            "../data.init/MNDO315/C6H11HW224O6OW112P1",
            "../data.init/MNDO315/C6H11HW208O6OW104P1",
            "../data.init/MNDO315/C6H11HW210O6OW105P1",
            "../data.init/MNDO315/C6H11HW222O6OW111P1",
            "../data.init/MNDO315/C6H11HW226O6OW113P1",
            "../data.init/MNDO315/C6H11HW228O6OW114P1",
            "../data.init/MNDO315/C6H11HW190O6OW95P1",
            "../data.init/MNDO315/C6H11HW252O6OW126P1",
            "../data.init/MNDO315/C6H11HW220O6OW110P1",
            "../data.init/MNDO315/C6H11HW182O6OW91P1",
            "../data.init/MNDO315/C6H11HW202O6OW101P1",
            "../data.init/MNDO315/C6H11HW206O6OW103P1",
            "../data.init/MNDO315/C6H11HW248O6OW124P1",
            "../data.init/MNDO315/C6H11HW178O6OW89P1",
            "../data.init/MNDO315/C6H11HW264O6OW132P1",
            "../data.init/MNDO315/C6H11HW236O6OW118P1",
            "../data.init/MNDO315/C6H11HW250O6OW125P1",
            "../data.init/MNDO315/C6H11HW232O6OW116P1",
            "../data.init/MNDO315/C6H11HW268O6OW134P1",
            "../data.init/MNDO315/C6H11HW218O6OW109P1",
            "../data.init/MNDO315/C6H11HW204O6OW102P1",
            "../data.init/MNDO315/C6H11HW214O6OW107P1",
            "../data.init/MNDO315/C6H11HW246O6OW123P1",
            "../data.init/MNDO315/C6H11HW186O6OW93P1",
            "../data.init/MNDO315/C6H11HW188O6OW94P1",
            "../data.init/MNDO315/C6H11HW180O6OW90P1",
            "../data.init/MNDO315/C6H11HW194O6OW97P1",
            "../data.init/MNDO315/C6H11HW198O6OW99P1",
            "../data.init/MNDO315/C6H11HW254O6OW127P1",
            "../data.init/MNDO315/C6H11HW196O6OW98P1",
            "../data.init/MNDO315/C6H11HW258O6OW129P1",
            "../data.init/MNDO315/C6H11HW184O6OW92P1",
            "../data.init/MNDO315/C6H11HW192O6OW96P1",
            "../data.init/MNDO330/C6H11HW260O6OW130P1",
            "../data.init/MNDO330/C6H11HW240O6OW120P1",
            "../data.init/MNDO330/C6H11HW212O6OW106P1",
            "../data.init/MNDO330/C6H11HW256O6OW128P1",
            "../data.init/MNDO330/C6H11HW230O6OW115P1",
            "../data.init/MNDO330/C6H11HW176O6OW88P1",
            "../data.init/MNDO330/C6H11HW262O6OW131P1",
            "../data.init/MNDO330/C6H11HW216O6OW108P1",
            "../data.init/MNDO330/C6H11HW238O6OW119P1",
            "../data.init/MNDO330/C6H11HW242O6OW121P1",
            "../data.init/MNDO330/C6H11HW200O6OW100P1",
            "../data.init/MNDO330/C6H11HW174O6OW87P1",
            "../data.init/MNDO330/C6H11HW244O6OW122P1",
            "../data.init/MNDO330/C6H11HW234O6OW117P1",
            "../data.init/MNDO330/C6H11HW224O6OW112P1",
            "../data.init/MNDO330/C6H11HW208O6OW104P1",
            "../data.init/MNDO330/C6H11HW210O6OW105P1",
            "../data.init/MNDO330/C6H11HW222O6OW111P1",
            "../data.init/MNDO330/C6H11HW226O6OW113P1",
            "../data.init/MNDO330/C6H11HW228O6OW114P1",
            "../data.init/MNDO330/C6H11HW190O6OW95P1",
            "../data.init/MNDO330/C6H11HW252O6OW126P1",
            "../data.init/MNDO330/C6H11HW266O6OW133P1",
            "../data.init/MNDO330/C6H11HW220O6OW110P1",
            "../data.init/MNDO330/C6H11HW182O6OW91P1",
            "../data.init/MNDO330/C6H11HW202O6OW101P1",
            "../data.init/MNDO330/C6H11HW206O6OW103P1",
            "../data.init/MNDO330/C6H11HW248O6OW124P1",
            "../data.init/MNDO330/C6H11HW178O6OW89P1",
            "../data.init/MNDO330/C6H11HW264O6OW132P1",
            "../data.init/MNDO330/C6H11HW236O6OW118P1",
            "../data.init/MNDO330/C6H11HW250O6OW125P1",
            "../data.init/MNDO330/C6H11HW232O6OW116P1",
            "../data.init/MNDO330/C6H11HW268O6OW134P1",
            "../data.init/MNDO330/C6H11HW218O6OW109P1",
            "../data.init/MNDO330/C6H11HW204O6OW102P1",
            "../data.init/MNDO330/C6H11HW214O6OW107P1",
            "../data.init/MNDO330/C6H11HW246O6OW123P1",
            "../data.init/MNDO330/C6H11HW186O6OW93P1",
            "../data.init/MNDO330/C6H11HW188O6OW94P1",
            "../data.init/MNDO330/C6H11HW180O6OW90P1",
            "../data.init/MNDO330/C6H11HW194O6OW97P1",
            "../data.init/MNDO330/C6H11HW198O6OW99P1",
            "../data.init/MNDO330/C6H11HW254O6OW127P1",
            "../data.init/MNDO330/C6H11HW196O6OW98P1",
            "../data.init/MNDO330/C6H11HW258O6OW129P1",
            "../data.init/MNDO330/C6H11HW184O6OW92P1",
            "../data.init/MNDO330/C6H11HW192O6OW96P1"
        ],
        "batch_size": [
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1,
            1
        ]
        },
        "seed": 515671943
    },
    "_comment": "that's all"
}

Steps to Reproduce

dp train input.json
dp freeze
dp compress input.json
Traceback (most recent call last):
  File "/home/jz748/anaconda3/bin/dp", line 10, in <module>
    sys.exit(main())
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/entrypoints/main.py", line 352, in main
    compress(**dict_args)
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/entrypoints/compress.py", line 104, in compress
    log_path=log_path,
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/entrypoints/train.py", line 211, in train
    _do_work(jdata, run_opt)
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/entrypoints/train.py", line 261, in _do_work
    model.build(train_data, stop_batch)
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/train/trainer.py", line 306, in build
    self.descrpt.enable_compression(self.min_nbor_dist, self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/descriptor/se_a.py", line 266, in enable_compression
    self.table = DeepTabulate(self.model_file, self.type_one_side)
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/utils/tabulate.py", line 66, in __init__
    self.bias = self._get_bias()
  File "/home/jz748/anaconda3/lib/python3.7/site-packages/deepmd/utils/tabulate.py", line 174, in _get_bias
    tensor_value = np.frombuffer(self.filter_variable_nodes["filter_type_" + str(int(ii / self.ntypes)) + "/bias_" + str(layer) + "_" + str(int(ii % self.ntypes))].tensor_content)
KeyError: 'filter_type_2/bias_1_2'
amcadmus commented 3 years ago

We should add a unit test (UT) that covers type exclusion (exclude_types) in model compression.