HelloJocelynLu / t5chem

Transformer-based model for chemical reactions
MIT License

How to predict yield? Can you give an example with a pre-trained model? #13

Closed jooewood closed 1 year ago

jooewood commented 1 year ago

Thanks!

HelloJocelynLu commented 1 year ago

Hi, glad to help. For reaction yield prediction, you can start with a sample dataset:

t5chem train --data_dir data/sample/regression/ --output_dir model/ --task_type regression --pretrain models/pretrain/simple/

Feel free to try more datasets as indicated in the README! Basically, one just needs to change --data_dir and --task_type for different tasks.

jooewood commented 1 year ago

Thanks! You mean if I want to predict Yield, I need to train it first? In your README, you mention that you provide two pretrained models: one is a simple model, the other is a multi-task model. In your paper, the USPTO_500_MT dataset has yield data.

I have tried two ways to predict yield using the pretrained model USPTO_500_MT you provided:

The first way follows the Jupyter example in the README, as for classification:

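import os

# Import path assumed to match the README's notebook example
from t5chem import T5ForProperty, SimpleTokenizer

pretrain_path = "/home/zdx/src/t5chem/model/USPTO_MT_model/models/USPTO_500_MT"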

beam_size = 10 # 1, 3, 5, 10
num_seq = 5 # 1, 3, 5, 10

task2prefix = {
    'product': "Product:",
    'reactants': "Reactants:",
    'reagents': "Reagents:",
    'regression': 'Regression:',
    'classification': 'Classification:',
    'yield': 'Yield:'
}

if num_seq > beam_size:
    raise ValueError("num_seq should not be larger than beam_size!")

task_type = "regression" # regression, classification
input_seq = 'COc1ccc2c(c1)C(=O)C(=O)N2CC(=O)OC(C)(C)C>ClCCl.O=C(O)C(F)(F)F>COc1ccc2c(c1)C(=O)C(=O)N2CC(=O)O'

model = T5ForProperty.from_pretrained(pretrain_path)  # for non-seq2seq task
tokenizer = SimpleTokenizer(vocab_file=os.path.join(pretrain_path, 'vocab.pt'))
model.eval()

input_seq = task2prefix[task_type]+input_seq
inputs = tokenizer.encode(input_seq, return_tensors='pt')
outputs = model(inputs)
print(outputs.logits)

Got

tensor([[ 0.3026, -0.7555,  0.0280, -0.8695,  0.1409,  0.9967, -0.4606, -1.1643,
         -0.2884,  0.1678,  3.6822, -0.2532,  0.3625, -0.3296, -1.5424, -0.4938,
         -2.4709,  0.4382,  0.5548,  0.6757,  0.2134,  0.5212, -2.3263,  0.1279,
         -0.6362, -0.5595, -1.1301, -0.5516,  5.4158,  2.7931, -0.0921, -0.7045,
         -2.1020, -0.2331,  5.5720,  0.8861, -0.7813, -0.9672, -0.4967, -0.7034,
          0.0125, -0.1558,  2.2171, -0.8316,  0.1382,  1.7859,  0.0142,  4.6209,
         -0.2086,  0.6761,  2.3287,  0.1810, -0.9961,  0.1662, -0.0508, -3.8013,
          1.6442, -0.2426, -0.4337,  2.2547,  1.1294, -1.3181, -1.6273,  1.1570,
         -0.3464, -0.4175,  2.6690, -5.7702,  1.0352, -0.5221, -0.8707,  0.2566,
          0.4965,  1.0497, -0.0999,  1.5595,  0.0648,  0.0746, -1.8363, -0.1938,
          0.0777,  0.0989, -0.9711,  0.1371, -0.6815, -0.0109,  4.4002, -0.7998,
          0.9913, -0.2225,  1.1177, -0.4997,  0.3207, -1.0187,  1.2267,  6.8115,
         -0.8689, -4.6292, -0.2953,  0.5917, -0.3433,  0.0619, -0.6809,  0.6104,
         -0.0083, -1.8088, -0.7464, -5.9843,  1.3281,  3.0981,  0.3121,  1.4193,
          0.0115, -1.4270, -0.2831, -7.1358, -0.5066, -0.5365, -2.7903, -0.9185,
         -2.8606, -0.4832, -0.1732, -1.0830, -0.3141,  1.3427, -0.6246,  0.1357,
         -1.0018,  0.8321,  0.6573, -2.0086,  1.7869,  1.2680,  0.4235,  0.3970,
          0.6529,  1.1463, -4.0193,  0.5198,  0.0840,  0.6493, -0.0121, -2.0056,
          0.4791, -1.6904,  3.6490, -0.6058, -0.8843,  0.9507, -0.5281, -1.1801,
         -0.1701,  1.8552, -1.0621, -2.6891, -0.8674, -1.2226, -0.5954,  1.5747,
          0.0507, -0.3681, -0.7776,  0.2586, -0.2254, -0.6206, -0.9036, -1.3796,
          1.5997, -4.5002,  0.3014, -0.7040,  1.3396,  0.8421,  0.0531, -0.8873,
          0.0194, -0.5464, -0.2072,  0.0133,  0.5327,  1.1011,  0.1534, -1.1708,
          0.4950,  0.1270, -2.8146, -4.6603,  1.0242, -0.9417,  1.0804,  0.3633,
          0.3177, -2.5605, -2.3539, -0.1273,  4.3921,  0.3010,  0.3306, -0.5685,
         -1.5707,  0.0930, -1.0409, -0.9430,  0.2919,  0.2021,  1.5711, -0.2015,
         -0.0209,  0.4264, -0.1783,  1.6269,  2.0920, -0.5366,  0.4457, -1.5688,
          0.9956, -2.1623, -0.2431, -6.3294,  2.4854,  1.2354, -0.5709,  1.3242,
         -0.8102,  3.7469,  0.3661,  0.4645, -0.1675,  0.2330, -0.6090,  4.7995,
         -0.1416,  0.4393, -1.1996,  6.4246,  0.0370,  0.4198, -0.2443, -3.5064,
          1.0308,  3.2977,  0.7541, -1.0116, -0.3588,  0.8327,  0.4605,  0.7342,
          0.6368, -0.5766, -0.0263, -0.1103, -5.3147, -0.3028, -0.2605,  0.5628]],
       grad_fn=<ViewBackward>)

Any suggestions on how to predict yield this way?

The other way is from the command line:

t5chem predict --data_dir /home/zdx/src/t5chem/data/USPTO_500_MT/data/USPTO_500_MT/Yield --model_dir /home/zdx/src/t5chem/model/USPTO_MT_model/models/USPTO_500_MT --prediction /home/zdx/Downloads/raw_predictions.csv --prefix regression
t5chem predict --data_dir /home/zdx/src/t5chem/data/USPTO_500_MT/data/USPTO_500_MT/Yield --model_dir /home/zdx/src/t5chem/model/USPTO_MT_model/models/USPTO_500_MT --prediction /home/zdx/Downloads/raw_predictions.csv --prefix yield

But I always got the same CSV: [image]

Any suggestions? Thanks!

HelloJocelynLu commented 1 year ago

Thanks! You mean if I want to predict Yield, I need to train it first?

Yes.

I have tried two ways to predict yield using the pretrained model USPTO_500_MT you provided:

As mentioned in the paper, seq2seq tasks and regression tasks have different output formats. Therefore, the models are not interchangeable (but you can train one from another). That is why neither direct test works. I believe you may have noticed a warning saying that some pre-trained weights were not loaded (and some layers were newly initialized). That is expected when you fine-tune a model, but not when you test it directly.
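To make the warning concrete, here is a minimal sketch of what "some weights are not loaded and some layers are newly initialized" means. It only compares parameter names; the names below are hypothetical placeholders, not the actual t5chem parameter names.

```python
def split_keys(checkpoint_keys, model_keys):
    """Compare parameter names in a checkpoint with those a model expects.

    Returns three sorted lists:
      loaded            - names present in both (weights are reused)
      newly_initialized - names the model needs but the checkpoint lacks
      unused            - names in the checkpoint that the model ignores
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(ckpt & model), sorted(model - ckpt), sorted(ckpt - model)


# Hypothetical parameter names, for illustration only.
seq2seq_checkpoint = [
    "encoder.block.0.weight",
    "decoder.block.0.weight",
    "lm_head.weight",          # seq2seq output layer
]
regression_model = [
    "encoder.block.0.weight",
    "decoder.block.0.weight",
    "regression_head.weight",  # regression output layer
]

loaded, newly_initialized, unused = split_keys(seq2seq_checkpoint, regression_model)
print("loaded:", loaded)
print("newly initialized (random until fine-tuned):", newly_initialized)
print("unused checkpoint weights:", unused)
```

Any layer that ends up in the "newly initialized" list produces arbitrary outputs until it has been trained, which is why testing directly gives meaningless numbers.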

I would recommend training a yield prediction model yourself; it only takes a few minutes on a high-throughput experimental dataset. Training on USPTO_500_MT takes longer, and the performance is worse (as reported in the paper).
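The train-then-predict workflow could be scripted roughly as follows. This is only a sketch: it reuses the flags that appear earlier in this thread, assumes the sample data directory also contains a test split, and uses a hypothetical output file name (predictions.csv). It prints the commands instead of running them.

```python
import shlex


def train_command(data_dir, output_dir, pretrain_dir, task_type="regression"):
    """Build the `t5chem train` call, using only flags shown in this thread."""
    return ["t5chem", "train",
            "--data_dir", data_dir,
            "--output_dir", output_dir,
            "--task_type", task_type,
            "--pretrain", pretrain_dir]


def predict_command(data_dir, model_dir, prediction_csv):
    """Build the `t5chem predict` call against the fine-tuned model directory."""
    return ["t5chem", "predict",
            "--data_dir", data_dir,
            "--model_dir", model_dir,
            "--prediction", prediction_csv]


# Step 1: fine-tune the pretrained model on the regression (yield) sample data.
train = train_command("data/sample/regression/", "model/", "models/pretrain/simple/")
# Step 2: predict with the fine-tuned model (predictions.csv is a hypothetical name).
predict = predict_command("data/sample/regression/", "model/", "predictions.csv")

for cmd in (train, predict):
    # To execute instead of print: subprocess.run(cmd, check=True)
    print(shlex.join(cmd))
```

The key point is that --model_dir in the second step points at the output of the first step, not at the original pretrained checkpoint.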