kiharalab / GNN_DOVE

Code for "Protein Docking Model Evaluation by Graph Neural Networks"
GNU General Public License v3.0

After getting the iinterface and linterface files, how to input them into the model? #4

Closed: sunchuance closed this issue 2 years ago

sunchuance commented 2 years ago

Hello, sorry to bother you. I have finished the data_processing step with my own decoys and obtained the iinterface and linterface files. How should I divide them into training and test sets and feed them into the model for training?

wang3702 commented 2 years ago

For the training/testing split, we described the procedure in our paper:

To remove redundancy, we grouped the 58 complexes using sequence alignment and TM-align ([Zhang and Skolnick, 2004](https://www.frontiersin.org/articles/10.3389/fmolb.2021.647915/full#B64)). Two complexes were assigned to the same group if at least one pair of proteins from the two complexes had a TM-score of over 0.5 and sequence identity of 30% or higher. This resulted in 29 groups ([Table 1](https://www.frontiersin.org/articles/10.3389/fmolb.2021.647915/full#T1)). In [Table 1](https://www.frontiersin.org/articles/10.3389/fmolb.2021.647915/full#T1), complexes (PDB IDs) of the same group are shown in lower case in a parenthesis followed by the PDB ID of the representative. These groups were split into four subgroups to perform four-fold cross-validation, where three subsets were used for training, while one testing subset was used for testing the accuracy of the model. Thus, by cross-validation, we have four models tested on four independent testing sets. Among the training set, we used 80% of the complexes (i.e., unique dimers) for training a model and the remaining 20% of the complexes as a validation set, which was used to determine the best hyper-parameter set for training.
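
As an illustration of the grouping rule quoted above, here is a minimal sketch using union-find. It assumes the TM-score and sequence identity have already been computed (for example with TM-align) and stored once per complex pair. The names `pair_scores`, `similar`, and `group_complexes`, the complex names, and the scores are all hypothetical. Treating similarity as transitive (single linkage) is also an assumption, not something the quoted text states.

```python
from itertools import combinations


def group_complexes(complexes, similar):
    """Single-linkage grouping: complexes linked by `similar` share a group."""
    parent = {c: c for c in complexes}

    def find(c):
        # Follow parent pointers to the group representative (with path halving).
        while parent[c] != c:
            parent[c] = parent[parent[c]]
            c = parent[c]
        return c

    for a, b in combinations(complexes, 2):
        if similar(a, b):
            parent[find(a)] = find(b)  # merge the two groups

    groups = {}
    for c in complexes:
        groups.setdefault(find(c), []).append(c)
    return sorted(sorted(g) for g in groups.values())


# Hypothetical scores per complex pair: (TM-score, sequence identity) of the
# most similar protein pair between the two complexes. Illustrative values only.
pair_scores = {
    frozenset(("cplxA", "cplxB")): (0.72, 0.45),
    frozenset(("cplxB", "cplxC")): (0.61, 0.33),
    frozenset(("cplxA", "cplxD")): (0.80, 0.12),  # similar fold, low identity
}


def similar(a, b, tm_cut=0.5, id_cut=0.30):
    """Both criteria must hold: TM-score over 0.5 and identity of 30% or higher."""
    tm, ident = pair_scores.get(frozenset((a, b)), (0.0, 0.0))
    return tm > tm_cut and ident >= id_cut


groups = group_complexes(["cplxA", "cplxB", "cplxC", "cplxD"], similar)
print(groups)
```

Here `cplxA`, `cplxB`, and `cplxC` end up in one group through the chain A-B and B-C, while `cplxD` stays alone because its sequence identity to `cplxA` is below the cutoff.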

We used 4-fold cross-testing because our dataset is small. If your dataset is larger, you can simply run a TM-score analysis to cluster the structures and then divide the clusters into train and test sets. Two things are important: 1) structures similar to those in the training set should not be placed in the test set; 2) all decoys of the same target must stay together, either in training or in testing. We cannot put some of them in training and others in testing.
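
The two rules above can be sketched as a group-aware split. This is only an illustration under stated assumptions: `split_by_group`, `decoy_targets`, `target_group`, and the decoy and target names are hypothetical, and the group ids are assumed to come from a prior TM-score clustering. Whole groups are assigned to one side, so similar targets and all decoys of a target land in the same set.

```python
import random


def split_by_group(decoy_targets, target_group, test_fraction=0.25, seed=0):
    """Split decoys into train/test so that no structural group straddles both sets."""
    groups = sorted(set(target_group.values()))
    random.Random(seed).shuffle(groups)  # reproducible shuffle of group ids
    n_test = max(1, round(len(groups) * test_fraction))
    test_groups = set(groups[:n_test])

    train, test = [], []
    for decoy, target in decoy_targets.items():
        # A decoy follows its target, and a target follows its group.
        if target_group[target] in test_groups:
            test.append(decoy)
        else:
            train.append(decoy)
    return train, test


# Hypothetical data: decoy name -> target, and target -> cluster id.
decoy_targets = {
    "t1_decoy1": "t1", "t1_decoy2": "t1",
    "t2_decoy1": "t2",
    "t3_decoy1": "t3", "t3_decoy2": "t3",
    "t4_decoy1": "t4",
    "t5_decoy1": "t5",
}
target_group = {"t1": 0, "t2": 0, "t3": 1, "t4": 2, "t5": 3}

train, test = split_by_group(decoy_targets, target_group)
print(len(train), len(test))
```

The same idea extends to cross-validation by cutting the shuffled group list into several folds instead of a single train/test pair.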

sunchuance commented 2 years ago

Thanks a lot for your answer, but I don't know how to feed the input files into GNN_Model.py, as shown in the picture below. Could you explain in more detail? Thanks again! [Image: 微信图片_20220219100217 (WeChat image)]

wang3702 commented 2 years ago

For this question, I don't think this repo is a good starting point for learning a training protocol. I suggest first reading a common training protocol for image classification. Here is a CIFAR-10 training tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html. Your training code should look similar to that protocol. You just need to specify a new dataset, which is a modification of the SingleDataset here, and specify the model, which is GNN_Model. Then training should be much simpler. GNN_Model is used in the same way as any other model, such as a ResNet. You should not assign the criterion and optimizer instances to the model. As for how to feed the input files to GNN_Model, that is what the dataset instance is for. You just need to modify the instance here: https://github.com/kiharalab/GNN_DOVE/blob/3ac8511014ab005e87e6f4b51f1fd97b14ce7d71/predict/predict_multi_input.py#L92. You should then be able to feed the inputs and labels to train the model easily.
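
To make the "dataset instance" idea concrete, here is a minimal sketch of a map-style dataset, meaning an object with `__len__` and `__getitem__`, which is the protocol a PyTorch `DataLoader` indexes. It is not the repo's SingleDataset: `DecoyDataset`, `load_fn`, the file names, and the inverse-class-frequency `weights` are all assumptions made for illustration. In real training code the class would typically subclass `torch.utils.data.Dataset`, and `load_fn` would read one processed decoy file.

```python
from collections import Counter


class DecoyDataset:
    """Map-style dataset over (path, label) pairs; hypothetical, for illustration."""

    def __init__(self, items, load_fn):
        self.items = list(items)   # (path, label) pairs, label 1 = acceptable decoy
        self.load_fn = load_fn     # assumed: reads one processed decoy from disk
        # Assumed weighting: inverse class frequency, so that a weighted sampler
        # draws positives and negatives about equally often.
        counts = Counter(label for _, label in self.items)
        self.weights = [1.0 / counts[label] for _, label in self.items]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        return self.load_fn(path), label


def fake_load(path):
    """Stand-in loader so the sketch runs without any data files."""
    return {"path": path}


items = [("decoy_1", 1), ("decoy_2", 0), ("decoy_3", 0), ("decoy_4", 0)]
dataset = DecoyDataset(items, fake_load)
print(len(dataset), dataset[0], dataset.weights)
```

The single positive decoy gets weight 1.0 and each of the three negatives gets 1/3, so both classes carry the same total weight.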

Here I have pasted my training code. I can't remember whether this is my final version, but it should include the most important parts of training.

    import os
    import random
    import time

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader

    # Dockground_Dataset, Data_Sampler, collate_fn, init_Logger, mkdir, save_checkpoint,
    # Val_GNN, AverageMeter and Calculate_binary_top_accuracy are helpers not shown here.

    # Shuffle the training list reproducibly, then split it into train/validation.
    random.seed(params['seed'])
    random.shuffle(use_train_list)
    n_train = int(len(use_train_list) * params['portion'])
    train_list = use_train_list[:n_train]
    val_list = use_train_list[n_train:]
    train_dataset = Dockground_Dataset(data_path, train_list)
    valid_dataset = Dockground_Dataset(data_path, val_list)
    train_sampler = Data_Sampler(train_dataset.weights, len(train_dataset.weights), replacement=True)
    val_sampler = Data_Sampler(valid_dataset.weights, len(valid_dataset.weights), replacement=True)
    # shuffle stays False because a sampler is supplied to each DataLoader.
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False,
                                  num_workers=params['num_workers'], sampler=train_sampler,
                                  collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size, sampler=val_sampler,
                                  shuffle=False, num_workers=params['num_workers'],
                                  collate_fn=collate_fn)
    test_dataset = Dockground_Dataset(data_path, test_list)
    test_sampler = Data_Sampler(test_dataset.weights, len(test_dataset.weights), replacement=True)
    test_dataloader = DataLoader(test_dataset, batch_size, sampler=test_sampler,
                                 shuffle=False, num_workers=params['num_workers'],
                                 collate_fn=collate_fn)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_logger, val_logger, test_logger = init_Logger(log_path)
    loss_fn = nn.BCELoss()
    best_acc = 0
    test_log_path = os.path.join(log_path, "test")
    mkdir(test_log_path)
    for k in range(epoch):
        train_loss, train_Accu = train_GNN(k, model, train_dataloader, train_logger, optimizer, loss_fn, device, params)
        val_loss, val_acc = Val_GNN(k, model, valid_dataloader, val_logger, loss_fn, device, params, log_path)
        test_loss, test_acc = Val_GNN(k, model, test_dataloader, test_logger, loss_fn, device, params, test_log_path)
        # Model selection uses the validation accuracy only.
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        state = {
            'epoch': k + 1,
            'state_dict': model.state_dict(),
            'loss': val_loss,
            'best_roc': best_acc,  # key name kept as is; the value is the best validation accuracy
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, is_best, checkpoint=model_path)
        print('Best acc:')
        print(best_acc)
        # Also save a checkpoint for every epoch.
        tmp_name = "model_" + str(k) + ".pth.tar"
        save_checkpoint(state, False, checkpoint=model_path, filename=tmp_name)

def train_GNN(epoch, model, train_dataloader, train_logger, optimizer, loss_fn, device, params):
    """Train for one epoch and return (average loss, average top-1 accuracy)."""
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    Loss = AverageMeter()
    Accu1 = AverageMeter()
    Accu2 = AverageMeter()
    iteration = int(len(train_dataloader))
    end_time = time.time()
    for batch_idx, sample in enumerate(train_dataloader):
        H, A1, A2, Y, V, Atom_count = sample
        data_time.update(time.time() - end_time)
        batch_size = H.size(0)
        # Move the tensors to the device passed in by the caller.
        H, A1, A2, Y, V = H.to(device), A1.to(device), A2.to(device), Y.to(device), V.to(device)
        pred = model.train_model((H, A1, A2, V, Atom_count), None)
        if batch_idx == 0:
            # Sanity check on the first batch: prediction shape, predictions, labels.
            print(pred.size())
            print(pred)
            print(Y)
        loss = loss_fn(pred, Y)
        optimizer.zero_grad()
        loss.backward()
        # Clip after backward() and before step(): gradients exist only once backward() has run.
        #torch.nn.utils.clip_grad_norm_(model.parameters(), params['clip'])
        torch.nn.utils.clip_grad_value_(model.parameters(), params['clip'])
        optimizer.step()
        with torch.no_grad():
            prec1 = Calculate_binary_top_accuracy(pred.data, Y.data)
        Accu1.update(prec1, batch_size)
        Loss.update(loss.item(), batch_size)
        batch_time.update(time.time() - end_time)
        end_time = time.time()
        print_str = 'Epoch: [{0}][{1}/{2}]\t Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                    ' Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                    ' Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                    'Top1 {acc.val:.3f} ({acc.avg:.3f})\t' .format(
            epoch,
            batch_idx + 1,
            iteration,
            batch_time=batch_time,
            data_time=data_time, loss=Loss, acc=Accu1
        )
        print(print_str)
    epoch_record_dict = {
        'epoch': epoch,
        'loss': Loss.avg,
        'lr': optimizer.param_groups[0]['lr'],
        'top1': Accu1.avg,
        'top5': Accu2.avg,  # note: Accu2 is never updated in this function
    }
    train_logger.log(epoch_record_dict)
    return Loss.avg, Accu1.avg
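
The snippet relies on an `AverageMeter` helper that is not shown. Here is a minimal sketch, assuming it behaves like the meter of the same name in the PyTorch ImageNet example, with the latest value in `.val` and the count-weighted running mean in `.avg`. The helper actually used for this training run may differ.

```python
class AverageMeter:
    """Tracks the most recent value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values seen so far
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average = sum / count

    def update(self, val, n=1):
        # `n` is the weight of this update, typically the batch size.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


loss_meter = AverageMeter()
loss_meter.update(0.8, n=4)    # batch of 4 samples with mean loss 0.8
loss_meter.update(0.2, n=12)   # batch of 12 samples with mean loss 0.2
print(loss_meter.val, loss_meter.avg)   # avg is approximately 0.35
```

Weighting each update by the batch size keeps the epoch average correct even when the last batch is smaller than the others.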