microsoft / AI2BMD

AI-powered ab initio biomolecular dynamics simulation
MIT License
383 stars 46 forks source link

Retrain teacher model of Pretrained 3D ViSNet #3

Closed CValse closed 1 year ago

CValse commented 1 year ago

Hi,

I would like to retrain the teacher model (Vanilla ViSNet) used in Pretrained 3D ViSNet with a PyTorch Lightning `LightningModule`, using something similar to:

class Vanilla_ViSNet(LightningModule):
    """LightningModule sketch for training the teacher (Vanilla ViSNet) model.

    Draft from the issue: ``step`` is left as a placeholder, so the class is
    not runnable as written.

    Args:
        hparams: hyperparameter container. It is stored with
            ``save_hyperparameters`` and read back as ``self.hparams.<name>``
            (e.g. lr, weight_decay, lr_factor, lr_patience, lr_min,
            lr_warmup_steps, loss_type).
        mean: forwarded unchanged to ``create_model``.
        std: forwarded unchanged to ``create_model``.
    """

    def __init__(self, hparams, mean=None, std=None):
        super(Vanilla_ViSNet, self).__init__()
        self.save_hyperparameters(hparams)
        self.model = create_model(self.hparams, mean, std)
        # initialize loss collection
        self._reset_losses_dict()
        self._reset_inference_results()

    def configure_optimizers(self):
        """Build AdamW over ``self.model`` parameters plus a plateau LR scheduler.

        ReduceLROnPlateau runs in "min" mode on the logged "val_epoch_loss"
        metric and is stepped once per epoch.

        Returns:
            ``([optimizer], [scheduler_config_dict])`` in Lightning's format.
        """
        optimizer = AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = ReduceLROnPlateau(optimizer,"min", factor=self.hparams.lr_factor, patience=self.hparams.lr_patience, min_lr=self.hparams.lr_min)  
        # "val_epoch_loss" must match the key logged in validation_epoch_end.
        lr_scheduler = {"scheduler": scheduler,"monitor": "val_epoch_loss","interval": "epoch", "frequency": 1}
        return [optimizer], [lr_scheduler]  

    def forward(self, data, stage):
        """Delegate to the wrapped model, passing ``stage`` through unchanged."""
        return self.model(data, stage)

    def training_step(self, batch, batch_idx):
        """Look up the loss by ``hparams.loss_type`` and run ``step`` for "train"."""
        loss_fn = loss_mapping_class[self.hparams.loss_type]
        return self.step(batch, loss_fn, "train")

    def validation_step(self, batch, batch_idx):
        """Look up the loss by ``hparams.loss_type`` and run ``step`` for "val"."""
        loss_fn = loss_mapping_class[self.hparams.loss_type]
        return self.step(batch, loss_fn, "val")

    def test_step(self, batch, batch_idx):
        """Look up the loss by ``hparams.loss_type`` and run ``step`` for "test"."""
        loss_fn = loss_mapping_class[self.hparams.loss_type]
        return self.step(batch, loss_fn, "test")

    def step(self, batch, loss_fn, stage):
        """Compute the loss for one batch in the given stage.

        NOTE(review): placeholder from the issue. ``loss`` is never assigned,
        so calling this raises NameError as written.

        validation_epoch_end stacks ``self.losses["train"]`` and
        ``self.losses["val"]``. Presumably this method appends per-batch
        losses there (TODO confirm). ``_reset_losses_dict`` defines no
        "test" key, so ``self.losses[stage]`` would fail for stage="test".
        """
        ...
        return loss

    def optimizer_step(self, *args, **kwargs):
        """Apply linear LR warmup, then step the optimizer and clear gradients.

        While ``global_step < lr_warmup_steps``, every param group's lr is set
        to ``lr * min(1, (global_step + 1) / lr_warmup_steps)``. After warmup
        this method no longer modifies the lr.
        """
        # Assumes Lightning passes the optimizer positionally as the third
        # argument (epoch, batch_idx, optimizer, ...) when it is not a
        # keyword. TODO confirm against the installed PL version.
        optimizer = kwargs["optimizer"] if "optimizer" in kwargs else args[2]
        if self.trainer.global_step < self.hparams.lr_warmup_steps:
            lr_scale = min(
                1.0,
                float(self.trainer.global_step + 1)
                / float(self.hparams.lr_warmup_steps),
            )

            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.lr
        super().optimizer_step(*args, **kwargs)
        optimizer.zero_grad()

    def validation_epoch_end(self, validation_step_outputs):
        """Log epoch-level metrics and reset the per-epoch loss buffers.

        Logs epoch, current lr, and mean train/val loss. Logging is skipped
        during Lightning's sanity-check validation run.

        NOTE(review): ``torch.stack`` raises on an empty list, so this fails
        if no train or val losses were recorded this epoch. Also, the
        ``*_epoch_end`` hooks were removed in PyTorch Lightning 2.0; confirm
        the installed version.
        """
        if not self.trainer.sanity_checking:
            # construct dict of logged metrics
            result_dict = {
                "epoch": float(self.current_epoch),
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
                "train_epoch_loss": torch.stack(self.losses["train"]).mean(),
                "val_epoch_loss": torch.stack(self.losses["val"]).mean(),
            }

            self.log_dict(result_dict, sync_dist=True)

        self._reset_losses_dict()
        # NOTE(review): self.results is not read anywhere else in this class;
        # test-time outputs live in self.inference_results.
        self.results = []

    def test_epoch_end(self, outputs) -> None:
        """Concatenate the collected per-batch inference results along dim 0.

        After this, each value in ``self.inference_results`` is a tensor
        rather than a list. Only ``__init__`` resets them to lists, so a
        second test run in the same module would not start from empty lists.
        """
        for key in self.inference_results.keys():
            self.inference_results[key] = torch.cat(self.inference_results[key], dim=0)

    def _reset_losses_dict(self):
        """Reset all per-stage loss buffers to empty lists."""
        # NOTE(review): there is no "test" key here.
        self.losses = {"train": [], "val": []}
        self.losses.update({"train_h": [], "train_top1_acc": [], "val_h": [], "val_top1_acc": []})
        self.losses.update({"train_e": [], "val_e": []})
        self.losses.update({"train_rdkit": [], "val_rdkit": []})
        self.losses.update({"train_eq": [], "val_eq": []})

    def _reset_inference_results(self):
        """Reset the containers for test-time predictions and targets."""
        self.inference_results = {'y_pred': [], 'y_true': []}

As hparams, I use:

# Hyperparameters used to train the teacher (Vanilla ViSNet) model.
accelerator: gpu
activation: silu
aggr: add
atom_feature: ["atomic_num"]
attn_activation: silu
attn_drop_rate: 0.2
batch_size: 32
bond_feature: []
conf: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: OGB_LSC
dataset_root: /ogb2022-dataset/Pretrained_3D_ViSNet_dataset
distance_influence: both
distance_otf: true
distributed_backend: ddp
drop_path_rate: 0.2
drop_rate: 0.2
early_stopping_patience: 150
embedding_dimension: 256
inference_batch_size: 32
lmax: 1
load_model: null
log_dir: /ogb2022-logs/Vanilla_ViSNet
loss_e_weight: 1
loss_pos_weight: 15.0
loss_pos_weight_decay_steps: 250000
loss_pos_weight_min: 1.0
loss_type: mse
lr: 0.0001
lr_factor: 0.8
lr_min: 1.0e-07
lr_patience: 15
lr_warmup_steps: 10000
max_num_neighbors: 32
mode: train
model: ViSNetBlock
neighbor_embedding: true
ngpus: -1
num_epochs: 3000
num_heads: 8
num_layers: 9
num_nodes: 1
num_rbf: 64
num_workers: 6
output_dpos: false
output_model: ScalarKD
precision: 32
rbf_type: expnorm
redirect: false
reduce_op: add
reload: 0
save_interval: 1
seed: 1
splits: null
standardize: false
test_interval: 100000
test_size: 0.1
train_size: 0.8
trainable_rbf: false
use_pos_kind: eq
val_size: 0.1
vecnorm_trainable: false
vecnorm_type: max_min
weight_decay: 0.0

and for the data, I use:

# Build the data module, then carve reproducible train/val/test splits out of
# the labelled training portion of the submission splits.
data = DataModule(hparams)
split_idx = data.dataset.get_submit_splits()
# Keep only the training molecules. After this, data.dataset is addressed by
# *position* (0..len-1) rather than by original dataset index, assuming
# index_select returns a positional view as in PyTorch Geometric.
data.dataset = data.dataset.index_select(split_idx["train"])
data.prepare_data()
train_ratio = 0.8
validation_ratio = 0.1
test_ratio = 0.1
# Fixed seed so the split is identical across runs (the maintainer suggests
# generating the splits with random seed 0).
split_seed = 0
# Split over positions within the subset. Reusing split_idx["train"] here would
# index the already-subsetted dataset with original ids, selecting the wrong
# molecules or going out of range.
positions = list(range(len(data.dataset)))
idx_train, idx_test = train_test_split(
    positions, test_size=1 - train_ratio, random_state=split_seed
)
idx_val, idx_test = train_test_split(
    idx_test,
    test_size=test_ratio / (test_ratio + validation_ratio),
    random_state=split_seed,
)
split_idx['train'] = idx_train
split_idx['valid'] = idx_val
split_idx['test-dev'] = idx_test

data.train_dataset = data.dataset.index_select(split_idx["train"])
data.val_dataset = data.dataset.index_select(split_idx["valid"])
# Defaults to the validation split unless hparams sets "inference_dataset"
# (e.g. to "test-dev").
data.test_dataset = data.dataset.index_select(split_idx[data.hparams.get("inference_dataset", "valid")])

# The config in use has no "task" key (it uses "mode: train"), so fall back to
# "mode" instead of raising KeyError when standardize is enabled.
if data.hparams["standardize"] and data.hparams.get("task", data.hparams.get("mode")) == "train":
    data._standardize()

However, I do not know exactly how to complete the VanillaViSNet class (in particular the step function), as all my attempts have failed.

Thanks!

v-shaoningli commented 1 year ago

Hi. Your Vanilla_ViSNet looks right. You can simply drop the data.dataset.get_submit_splits() call and generate the splits with the same random seed (random_seed=0). The training command is CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --conf examples/<your_config> --dataset-root <your_dataset_root> --log-dir <your_log_dir>.

CValse commented 1 year ago

Thanks, but I cannot use the train.py script as-is to train the teacher model, for several reasons. For example, the yaml file contains a few configuration options that train.py does not accept (they raise an 'Unknown argument' error), and a ViSNet module that is trained only on eq structures would have to be added... or am I missing something?

v-shaoningli commented 1 year ago

Sorry for the inconvenience. The full codebase for training the (teacher) ViSNet on different downstream tasks (MD17/rMD17/MD22/QM9/Molecular3D, etc.) will be released once the ViSNet paper is published. We expect this to happen soon. In the meantime, you can modify the OGB-LSC codebase (e.g., by removing the verbose argparse options) to simplify the training. Thanks again for your interest, and apologies for the inconvenience.

Eipgen commented 1 year ago

> Sorry for the inconvenience. The whole codebase of training (teacher) ViSNet for different downstream tasks on MD17/rMD17/MD22/QM9/Molecular3D etc. would be released until the paper on ViSNet is published. We believe it will be released soon but you can modify the OGB-LSC codebase (deleting the verbose argparse and so on to reduce the training). Thanks again for your interest and my apology for the inconvenience.

Hi, developers: which branch can be used to reproduce the reported performance on the MD17 datasets? I checked the branches but can't find any keyword about energy or force in the current branch.