kheyer / gpt2_zinc_87m

GPT2 language model trained on 480M SMILES from ZINC

Inquiry about Training Data Input Format for GPT2 Zinc 87m Model #1

Open ByteTora opened 1 month ago

ByteTora commented 1 month ago

Hello, Karl

I hope this message finds you well. I am currently exploring the GPT2 Zinc 87m model you've made available and am very interested in fine-tuning it for a specific application in my research.

After reviewing the documentation you provided, I noticed that the details regarding the input format for training data are not explicitly mentioned. To proceed with my fine-tuning efforts, I would greatly appreciate some guidance on how the training data should be structured and formatted.

Specifically, I am looking for information on:

  1. The expected format of the SMILES strings for training.
  2. Any preprocessing steps or tokenization requirements before feeding the data into the model.
  3. How to handle special tokens such as 'bos' and 'eos', as mentioned in the warning section of the documentation. I understand that GPT2TokenizerFast does not automatically add special tokens, even when add_special_tokens=True. Could you please advise on the best practice for manually adding these tokens to the training data?

Additionally, if there are any examples or additional resources that you could share, it would be incredibly helpful.

Thank you very much for your time and assistance. I am excited about the potential of your model and look forward to contributing to the community with my fine-tuned version.

Best regards,

Tora

kheyer commented 1 month ago

Hi Tora,

Thanks for your interest in the model.

Preprocessing steps:

  1. Canonicalize SMILES strings
  2. Add bos/eos tokens manually
  3. Tokenize

Here is a basic example:

```python
import datasets
from rdkit import Chem
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("entropy/gpt2_zinc_87m")

smiles = ['Brc1cc2c(NCc3ccccc3)ncnc2s1',
          'Brc1cc2c(NCc3ccccn3)ncnc2s1',
          'Brc1cc2c(NCc3cccs3)ncnc2s1',
          'Brc1cc2c(NCc3ccncc3)ncnc2s1',
          'Brc1cc2c(Nc3ccccc3)ncnc2s1']

dataset = datasets.Dataset.from_dict({'smiles': smiles})

def preprocess_smile(row, tokenizer):
    try:
        # canonicalize
        smile = Chem.CanonSmiles(row['smiles'])

        # add special tokens manually (the tokenizer will not add them)
        smile = tokenizer.bos_token + smile + tokenizer.eos_token

        output = {'smiles': smile}
    except Exception:
        # canonicalization fails for invalid SMILES
        output = {'smiles': ''}

    return output

dataset = dataset.map(lambda row: preprocess_smile(row, tokenizer))

# remove SMILES that failed canonicalization
dataset = dataset.filter(lambda row: row['smiles'] != '')

def tokenization(row):
    # possibly add max length/truncation or other tokenizer arguments
    return tokenizer(row["smiles"])

dataset = dataset.map(tokenization)
dataset.save_to_disk('smiles_dataset_tokenized.hf')
```
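The `tokenization` step above leaves length handling open ("possibly add max length/truncation"). The training config later in this thread sets `n_positions=256`, so a tokenized SMILES longer than 256 tokens (bos/eos included) cannot be fed to the model as-is. One thing worth checking: if you enable `truncation=True` with a `max_length`, the tokenizer's default right-side truncation keeps only the first `max_length` tokens, so an over-long sequence loses its trailing eos token. The minimal stdlib sketch below illustrates this with made-up token ids (`BOS_ID` and `EOS_ID` are hypothetical, not the real vocabulary) and shows a simple filter predicate as an alternative.

```python
# Hypothetical token ids for illustration only; the real values come from
# tokenizer.bos_token_id and tokenizer.eos_token_id.
BOS_ID, EOS_ID = 0, 1
MAX_POSITIONS = 256  # n_positions in the GPT2Config used in this thread


def truncate(input_ids, max_len=MAX_POSITIONS):
    """Right-side truncation: keep only the first max_len token ids."""
    return input_ids[:max_len]


def fits_context(input_ids, max_len=MAX_POSITIONS):
    """Filter predicate: True if the whole sequence, bos/eos included, fits."""
    return len(input_ids) <= max_len


short_ids = [BOS_ID, 10, 11, 12, EOS_ID]
long_ids = [BOS_ID] + [10] * 300 + [EOS_ID]

print(truncate(long_ids)[-1] == EOS_ID)  # False: the eos token was cut off
print([fits_context(ids) for ids in (short_ids, long_ids)])  # [True, False]
```

With a `datasets.Dataset`, the same predicate could be applied after the `tokenization` step as `dataset.filter(lambda row: len(row["input_ids"]) <= 256)`. Whether to drop or truncate long molecules is a judgment call for your data.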

Training is pretty standard. Make sure to use DataCollatorForLanguageModeling (with mlm=False) as the collator. I would also suggest fine-tuning in fp16, since the base model was trained in fp16. Here is an example of the code used to train from scratch; you can adapt it for fine-tuning.

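```python
import datasets
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    Trainer,
    TrainingArguments,
)

dataset = datasets.Dataset.load_from_disk('smiles_dataset_tokenized.hf')
# keep only input_ids; the collator pads them and builds the labels
dataset = dataset.remove_columns(['smiles', 'attention_mask'])
dataset = dataset.with_format("torch")

# load the tokenizer before creating the collator that uses it
tokenizer = GPT2TokenizerFast.from_pretrained("entropy/gpt2_zinc_87m")

# mlm=False -> causal language modeling (labels are the input ids;
# the model shifts them internally)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=256,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# randomly initialized model: this trains from scratch
model = GPT2LMHeadModel(config)

# fill in your own training arguments (output_dir, batch size, fp16=True, ...)
args = TrainingArguments(...)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=dataset,
)

trainer.train()
```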
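The script above trains from scratch with a freshly initialized `GPT2Config`. For fine-tuning, the main change is to load the released weights with `GPT2LMHeadModel.from_pretrained` instead. Below is a minimal sketch under that assumption; the function name `build_finetune_trainer` and the `output_dir` default are hypothetical, and batch size, learning rate and other hyperparameters are deliberately left out rather than guessed.

```python
def build_finetune_trainer(train_dataset, output_dir="gpt2_zinc_finetuned"):
    """Hypothetical helper: fine-tune the released checkpoint on a prepared dataset."""
    # Imported inside the function so that defining it needs no transformers install.
    from transformers import (
        DataCollatorForLanguageModeling,
        GPT2LMHeadModel,
        GPT2TokenizerFast,
        Trainer,
        TrainingArguments,
    )

    tokenizer = GPT2TokenizerFast.from_pretrained("entropy/gpt2_zinc_87m")
    # Pretrained weights replace GPT2Config(...) + GPT2LMHeadModel(config).
    model = GPT2LMHeadModel.from_pretrained("entropy/gpt2_zinc_87m")

    # Same causal-LM collator as in the from-scratch script.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    args = TrainingArguments(
        output_dir=output_dir,  # hypothetical output path
        fp16=True,  # matches the fp16 suggestion above; requires a CUDA GPU
        # batch size, learning rate, epochs, etc. are left to you
    )

    return Trainer(
        model=model,
        args=args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
```

Usage would mirror the script above: prepare `dataset` the same way (load from disk, drop the `smiles` and `attention_mask` columns, set the torch format), then call `build_finetune_trainer(dataset).train()`. Moving the imports to the top of the file works just as well; keeping them local is only a convenience.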

ByteTora commented 1 month ago

Got it, thanks! I'll begin the fine-tuning process shortly.😊