isayev / ReLeaSE

Deep Reinforcement Learning for de-novo Drug Design
MIT License

The number of tokens is inconsistent with the tokens provided in `LogP_optimization_demo.ipynb`; can transfer learning still be used? #30

Open zhouhao-learning opened 5 years ago

zhouhao-learning commented 5 years ago

Hello, I am training a generative model on my own SMILES data using LogP_optimization_demo.ipynb, which defines tokens = ['<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0', '3', '2', '5', '4', '7', '6', '9', '8', '=', 'A', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']', '\\', 'c', 'e', 'i', 'l', 'o', 'n', 'p', 's', 'r', '\n']. However, my data contains characters outside this tokens list, so I cannot continue training with the transfer learning method. I therefore changed the code as follows:

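import sys
sys.path.append('./release/')  # ReLeaSE modules (data.py, stackRNN.py) live in ./release/
import torch
from data import GeneratorData
from stackRNN import StackAugmentedRNN

use_cuda = torch.cuda.is_available()

gen_data_path = "data/nueji_data2.csv"
# tokens=None: the vocabulary is built from the characters in this file
gen_data = GeneratorData(training_data_path=gen_data_path, delimiter='\t', 
                         cols_to_read=[0], keep_header=True, tokens=None)
hidden_size = 1500
stack_width = 1500
stack_depth = 200
layer_type = 'GRU'
lr = 0.001
optimizer_instance = torch.optim.Adadelta

# Input and output sizes equal the vocabulary size of the new dataset
my_generator = StackAugmentedRNN(input_size=gen_data.n_characters, hidden_size=hidden_size,
                                 output_size=gen_data.n_characters, layer_type=layer_type,
                                 n_layers=1, is_bidirectional=False, has_stack=True,
                                 stack_width=stack_width, stack_depth=stack_depth, 
                                 use_cuda=use_cuda, 
                                 optimizer_instance=optimizer_instance, lr=lr)
model_path = './checkpoints/generator/checkpoint_biggest_rnn'
# Loading fails if the checkpoint was saved with a different vocabulary size
my_generator.load_model(model_path)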

But I get the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-11-3c9498b26c8c> in <module>()
----> 1 my_generator.load_model(model_path)

/scratch2/hzhou/Drug/generate_smiles/ReLeaSE/release/stackRNN.py in load_model(self, path)
    140         """
    141         weights = torch.load(path)
--> 142         self.load_state_dict(weights)
    143 
    144     def save_model(self, path):

~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    717         if len(error_msgs) > 0:
    718             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 719                                self.__class__.__name__, "\n\t".join(error_msgs)))
    720 
    721     def parameters(self):

RuntimeError: Error(s) in loading state_dict for StackAugmentedRNN:
    size mismatch for encoder.weight: copying a param of torch.Size([40, 1500]) from checkpoint, where the shape is torch.Size([45, 1500]) in current model.
    size mismatch for decoder.weight: copying a param of torch.Size([40, 1500]) from checkpoint, where the shape is torch.Size([45, 1500]) in current model.
    size mismatch for decoder.bias: copying a param of torch.Size([40]) from checkpoint, where the shape is torch.Size([45]) in current model.
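All three mismatched parameters have the vocabulary size as their first dimension (`encoder.weight`, `decoder.weight`, `decoder.bias`), so the checkpoint and the current model disagree on how many tokens there are. The following is a minimal diagnostic sketch. It assumes the checkpoint is a plain `state_dict`, as `load_model` in the traceback suggests, and the helper name `vocab_size_mismatches` is hypothetical:

```python
# Hypothetical diagnostic helper: compares the vocabulary-sized parameters of a
# saved StackAugmentedRNN state_dict with the vocabulary size of the new data.
VOCAB_SIZED_PARAMS = ("encoder.weight", "decoder.weight", "decoder.bias")


def vocab_size_mismatches(state_dict, n_characters):
    """Return {name: (checkpoint_size, model_size)} for every mismatched param.

    The first dimension of each listed parameter equals the number of tokens
    the model was built with, so any difference means the vocabularies differ.
    """
    mismatches = {}
    for name in VOCAB_SIZED_PARAMS:
        checkpoint_size = state_dict[name].shape[0]
        if checkpoint_size != n_characters:
            mismatches[name] = (checkpoint_size, n_characters)
    return mismatches
```

With PyTorch, calling `weights = torch.load(model_path, map_location='cpu')` and then `vocab_size_mismatches(weights, gen_data.n_characters)` would report the mismatch before `load_state_dict` is ever called.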

However, my dataset is very small. Without transfer learning, my generative model may not be able to learn the chemical rules of SMILES. So my idea is this: retrain a model on the `data/chembl_22_clean_1576904_sorted_std_final.smi` dataset with a custom tokens list that also includes the characters in my dataset, and then fine-tune that pretrained model on my data. Is this idea right? I'm not sure.
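The idea above, pretraining with a token list that already covers both datasets, starts from a vocabulary that is the union of the demo tokens and every character in the new data. Below is a minimal pure-Python sketch. `extend_tokens` is a hypothetical helper, and it assumes one-character tokens, as in the demo's tokens list:

```python
# Character-level tokens from LogP_optimization_demo.ipynb (one character each).
DEMO_TOKENS = ['<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0',
               '3', '2', '5', '4', '7', '6', '9', '8', '=', 'A', '@', 'C',
               'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']', '\\', 'c',
               'e', 'i', 'l', 'o', 'n', 'p', 's', 'r', '\n']


def extend_tokens(base_tokens, smiles_list):
    """Return base_tokens followed by any new characters found in smiles_list.

    Existing tokens keep their positions; new characters are appended in
    sorted order so the resulting list is reproducible.
    """
    known = set(base_tokens)
    new_chars = sorted({ch for smi in smiles_list for ch in smi} - known)
    return list(base_tokens) + new_chars
```

The same extended list would then have to be passed as `tokens=` to `GeneratorData` for both the ChEMBL pretraining and the fine-tuning run so that `n_characters` matches. The existing checkpoint cannot be reused directly, because its layers were sized for the original vocabulary.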

isayev commented 5 years ago

What kinds of extra characters do you have? You probably need to standardize your SMILES (remove metals, mixtures, stereochemistry, etc.).
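To find out which extra characters are present, one can scan the data for characters missing from the token list and keep a few example SMILES for each. The following is a minimal sketch. `find_unknown_chars` is a hypothetical helper, and, like the demo, it treats every single character as a token:

```python
from collections import defaultdict


def find_unknown_chars(smiles_list, tokens, max_examples=3):
    """Map each character that is not in `tokens` to a few SMILES containing it.

    Assumes one character per token, as in the demo's tokens list.
    """
    vocab = set(tokens)
    examples = defaultdict(list)
    for smi in smiles_list:
        for ch in sorted(set(smi) - vocab):
            if len(examples[ch]) < max_examples:
                examples[ch].append(smi)
    return dict(examples)
```

Called with the demo's tokens, a sodium or calcium salt such as `CC(=O)[O-].[Na+]` is reported under the character `a`, which points straight at the offending counterions.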

zhouhao-learning commented 5 years ago

@isayev My SMILES contain the extra character `a`, because some of them include Na and Ca. What do you mean by standardized SMILES, and what do I need to do? Thank you.

quangnguyenbn99 commented 4 years ago

Hi @zhouhao-learning, did you solve your problem? I am facing the same issue. If you have the solution, please enlighten me.

gmseabra commented 4 years ago

Although the question is old, I'm answering it now because it still seems to be unresolved.

Basically, the point is that you generally don't want ions (Na+, Ca2+) in your compound library, since they are just counterions to your compound. So you need to remove them from your SMILES data before using it.
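As a rough illustration of removing counterions, the sketch below splits a SMILES string on `.` (the fragment separator) and keeps the longest fragment. This naive string heuristic is an assumption of this example, not MolVS's method. Proper standardization tools such as MolVS or RDKit's `rdMolStandardize` choose the parent fragment chemically (for example, by counting atoms rather than characters) and can also neutralize charges. `keep_longest_fragment` is a hypothetical helper:

```python
def keep_longest_fragment(smiles):
    """Keep the longest '.'-separated fragment of a SMILES string.

    Naive heuristic: salts such as sodium acetate are written as several
    fragments joined by '.', and the counterion is usually the short one.
    String length is only a proxy for molecular size, so check the results.
    """
    fragments = [frag for frag in smiles.split(".") if frag]
    return max(fragments, key=len) if fragments else smiles
```

For `CC(=O)[O-].[Na+]` this keeps `CC(=O)[O-]`, so the `[Na+]` fragment that introduced the character `a` is dropped. The remaining `[O-]` only uses characters already in the demo's token list. Mixtures of two similar-sized organic molecules are where this heuristic is most likely to pick the wrong fragment.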

Take a look at: https://molvs.readthedocs.io/en/latest/

Best.