HICAI-ZJU / KANO

Code and data for the Nature Machine Intelligence paper "Knowledge graph-enhanced molecular contrastive learning with functional prompt".
MIT License
108 stars 24 forks

I'm having some problems with the pretrain program #2

Closed YunQingYangZzu closed 1 year ago

YunQingYangZzu commented 1 year ago

Hi, I'm having some problems with pretrain. I followed the README instructions and ran:

    python pretrain.py --exp_name 'pre-train' --exp_id 1 --step pretrain

but an error occurs during the run:

Traceback (most recent call last):
  File "/home/yqyang/.pycharm_helpers/pydev/pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/yqyang/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/yqyang/projects/KANO/pretrain.py", line 35, in <module>
    pretrain(args, logger)
  File "/home/yqyang/projects/KANO/pretrain.py", line 17, in pretrain
    pre_training(args, logger)
  File "/home/yqyang/projects/KANO/chemprop/train/run_training.py", line 415, in pre_training
    emb1 = model1(step, False, batch, None)
  File "/home/yqyang/anaconda3/envs/KANO/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yqyang/projects/KANO/chemprop/models/model.py", line 112, in forward
    output = self.encoder(*input)
  File "/home/yqyang/anaconda3/envs/KANO/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yqyang/projects/KANO/chemprop/models/cmpn.py", line 213, in forward
    output = self.encoder.forward(step, batch, features_batch)
  File "/home/yqyang/projects/KANO/chemprop/models/cmpn.py", line 111, in forward
    input_atom = self.act_func(input_atom)
UnboundLocalError: local variable 'input_atom' referenced before assignment

Process finished with exit code 134

I checked the code in cmpn.py and found that when args.step == 'pretrain', none of the branches in the forward function assigns input_atom, so an error occurs when input_atom = self.act_func(input_atom) is executed. I hope you can help me solve this problem, thank you very much!
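For reference, the mechanism behind the error can be reproduced in isolation: in Python, a name assigned only inside some branches of an if/elif chain is unbound when no branch matches. This is a minimal, simplified sketch (the step names mirror cmpn.py, but the arithmetic is a stand-in for the real tensor operations):

```python
# Minimal repro of the UnboundLocalError, simplified from the forward
# function in cmpn.py: input_atom is only assigned inside the three
# step-specific branches, so step == 'pretrain' falls through with the
# name still unbound.
def forward_buggy(step, f_atoms):
    if step == 'functional_prompt':
        input_atom = f_atoms + 1
    elif step == 'finetune_add':
        input_atom = f_atoms + 2
    elif step == 'finetune_concat':
        input_atom = f_atoms + 3
    # no 'else' branch: 'pretrain' reaches the next line unassigned
    return input_atom * 10  # UnboundLocalError when step == 'pretrain'
```

Calling forward_buggy('pretrain', 1) raises UnboundLocalError, while any of the three handled steps works, which matches the traceback above.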

    def forward(self, step, mol_graph, features_batch=None) -> torch.FloatTensor:

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, atom_num, fg_num, f_fgs, fg_scope = mol_graph.get_components()
        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb, f_fgs = (
                    f_atoms.cuda(), f_bonds.cuda(), 
                    a2b.cuda(), b2a.cuda(), b2revb.cuda(), f_fgs.cuda())

        fg_index = [i*13 for i in range(mol_graph.n_mols)]
        fg_indxs = [[i]*133 for i in fg_index]
        fg_indxs = torch.LongTensor(fg_indxs).cuda()
        # a2a = mol_graph.get_a2a().cuda()

        if self.args.step == 'functional_prompt':
            # make sure the prompt exists
            assert self.W_i_atom.prompt_generator
            # Input
            input_atom = self.W_i_atom(f_atoms)  # num_atoms x hidden_size
            input_atom = self.W_i_atom.prompt_generator(input_atom, f_fgs, atom_num, fg_indxs)

        elif self.args.step == 'finetune_add':
            for i in range(len(fg_indxs)):
                f_fgs.scatter_(0, fg_indxs[i:i+1], self.cls)

            target_index = [val for val in range(mol_graph.n_mols) for i in range(13)]
            target_index = torch.LongTensor(target_index).cuda()
            fg_hiddens = scatter_add(f_fgs, target_index, 0)
            fg_hiddens_atom = torch.repeat_interleave(fg_hiddens, torch.tensor(atom_num).cuda(), dim=0)
            fg_out = torch.zeros(1, 133).cuda()
            fg_out = torch.cat((fg_out, fg_hiddens_atom), 0)
            f_atoms += fg_out
            # Input
            input_atom = self.W_i_atom(f_atoms)  # num_atoms x hidden_size

        elif self.args.step == 'finetune_concat':
            for i in range(len(fg_indxs)):
                f_fgs.scatter_(0, fg_indxs[i:i+1], self.cls)

            target_index = [val for val in range(mol_graph.n_mols) for i in range(13)]
            target_index = torch.LongTensor(target_index).cuda()
            fg_hiddens = scatter_add(f_fgs, target_index, 0)
            fg_hiddens_atom = torch.repeat_interleave(fg_hiddens, torch.tensor(atom_num).cuda(), dim=0)
            fg_out = torch.zeros(1, 133).cuda()
            fg_out = torch.cat((fg_out, fg_hiddens_atom), 0)
            f_atoms = torch.cat((fg_out, f_atoms), 1)
            # Input
            input_atom = self.W_i_atom_new(f_atoms)  # num_atoms x hidden_size

        input_atom = self.act_func(input_atom)
ZJU-Fangyin commented 1 year ago

Hi, we apologize for the inconvenience caused by the previous code. To resolve the issue, please add the following else branch after the existing step-specific branches in the forward function of cmpn.py:

else:
    input_atom = self.W_i_atom(f_atoms)  # num_atoms x hidden_size

We have also updated the code on GitHub to reflect this change. Thank you for bringing this to our attention! Please feel free to contact me if you have any further questions.
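Continuing the simplified sketch from above, the suggested fix makes the if/elif chain exhaustive, so input_atom is assigned on every path (here the else body is a stand-in for self.W_i_atom(f_atoms)):

```python
# Same simplified sketch, with the default branch added: every step
# value, including 'pretrain', now assigns input_atom before use.
def forward_fixed(step, f_atoms):
    if step == 'functional_prompt':
        input_atom = f_atoms + 1
    elif step == 'finetune_add':
        input_atom = f_atoms + 2
    elif step == 'finetune_concat':
        input_atom = f_atoms + 3
    else:
        # default branch covers 'pretrain' (and any other step value);
        # in the real code this is input_atom = self.W_i_atom(f_atoms)
        input_atom = f_atoms
    return input_atom * 10
```

With the fallback in place, forward_fixed('pretrain', 1) returns normally instead of raising.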

YunQingYangZzu commented 1 year ago

Thank you very much for your reply and solved my problem