seyonechithrananda / bert-loves-chemistry

bert-loves-chemistry: a repository of HuggingFace models applied on chemical SMILES data for drug design, chemical modelling, etc.
MIT License

ChemBERTa

ChemBERTa: A collection of BERT-like models applied to chemical SMILES data for drug design, chemical modelling, and property prediction. To be presented at Baylearn and the Royal Society of Chemistry's Chemical Science Symposium.

Tutorial
arXiv ChemBERTa-2 Paper
arXiv ChemBERTa Paper
Poster
Abstract
BibTeX

License: MIT License

The notebooks currently all use the RoBERTa model (a variant of BERT) trained on the task of masked-language modelling (MLM). Training ran for 10 epochs on the ZINC 250k dataset, by which point the loss had converged to around 0.26. The model weights for ChemBERTa pre-trained on various datasets (ZINC 100k, ZINC 250k, PubChem 100k, PubChem 250k, PubChem 1M, PubChem 10M) are available through HuggingFace. We expect to continue releasing larger models pre-trained on larger subsets of ZINC, ChEMBL, and PubChem in the near future.

This library is currently mostly a set of notebooks covering our pre-training and fine-tuning setup. It will be updated soon with the model implementation and attention visualization code, likely after the arXiv publication. Stay tuned!

I hope this is of use to developers, students and researchers exploring the use of transformers and the attention mechanism for chemistry!

Citing Our Work

Please cite the ChemBERTa-2 arXiv paper if you have used these models, notebooks, or examples in any way. The link to the BibTeX is available here.

Example

You can load the tokenizer and model for MLM prediction tasks using the following code:

from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

# Any model weights from the link above will work here.
# AutoModelForMaskedLM replaces the deprecated AutoModelWithLMHead.
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Build a fill-mask pipeline that predicts masked SMILES tokens.
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
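Once the pipeline is built, it can be queried with a SMILES string in which one position has been replaced by the tokenizer's mask token. The following is a minimal sketch, not code from the notebooks: the helper `mask_last_char`, the function `demo`, and the benzene input are hypothetical names and values chosen for illustration, and the `"<mask>"` default assumes a RoBERTa-style tokenizer (the sketch reads the real value from `tokenizer.mask_token` when it runs).

```python
def mask_last_char(smiles: str, mask_token: str = "<mask>") -> str:
    """Replace the final character of a SMILES string with the mask token.

    Hypothetical helper for illustration; it is not part of this repository.
    """
    if not smiles:
        raise ValueError("SMILES string must be non-empty")
    return smiles[:-1] + mask_token


def demo() -> None:
    """Download the checkpoint and print fill-mask predictions (needs network)."""
    # Imported here so the helper above works without transformers installed.
    from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

    name = "seyonec/ChemBERTa-zinc-base-v1"
    model = AutoModelForMaskedLM.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

    # Benzene with its closing ring digit masked (illustrative input only).
    masked = mask_last_char("C1=CC=CC=C1", tokenizer.mask_token)

    # Each prediction is a dict; "sequence" is the completed string and
    # "score" is the model's probability for that completion.
    for pred in fill_mask(masked):
        print(pred["sequence"], round(pred["score"], 4))
```

Calling `demo()` downloads the weights on first use and prints the top candidate completions with their scores. Masking a different position only requires building the input string differently, as long as it contains exactly one mask token.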

Todo: