HaoZhongkai / GNOT

Multi-GPU training? #10

Open cfos3120 opened 7 months ago

cfos3120 commented 7 months ago

I have adapted this using simple DataParallel from PyTorch, but the model sometimes outputs NaNs. Have you been able to train this across multiple GPUs on a single node?
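
For reference, the adaptation is essentially the standard torch.nn.DataParallel wrapping (a rough sketch, not my exact script; `model` stands for the GNOT model built elsewhere, e.g. via get_model(args)):

import torch
import torch.nn as nn

# rough sketch of the wrapping; `model` is the GNOT model built elsewhere (e.g. get_model(args))
model = model.to('cuda:0')
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)   # replicates the model and splits tensor inputs across the visible GPUs

# the rest of the training loop is unchanged: out = model(g, u_p, g_u), loss.backward(), optimizer.step()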

HaoZhongkai commented 7 months ago

I guess the NaN might be caused by something else. You could check whether any channel of your data is all zeros.
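
For example, something like this (a rough sketch, assuming each sample unpacks to (g, u_p, g_u) with targets stored in g.ndata['y'], as in the training loop below):

import torch

# rough sketch: report target channels that are identically zero across the dataset
def find_zero_channels(dataset):
    ys = torch.cat([g.ndata['y'] for g, _, _ in dataset], dim=0)      # (total_points, n_channels)
    zero_channels = (ys.abs().max(dim=0).values == 0).nonzero(as_tuple=True)[0]
    return zero_channels.tolist()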

I'm able to train GNOT on multiple GPU nodes. Here is some sample code using accelerate (I'm not sure whether it is compatible with the current version):

import os
import sys
import pickle
import tqdm
import re
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import dgl
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from accelerate import Accelerator

from collections import OrderedDict
from torch.optim.lr_scheduler import OneCycleLR, StepLR, LambdaLR, CosineAnnealingWarmRestarts, CyclicLR
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset

from gnot.args import get_args
from gnot.data_utils import get_dataset, get_model, get_loss_func, collate_op, MIODataLoader
from gnot.utils import get_seed, get_num_params, timing
from gnot.models.optimizer import Adam, AdamW

EPOCH_SCHEDULERS = ['ReduceLROnPlateau', 'StepLR', 'MultiplicativeLR',
                    'MultiStepLR', 'ExponentialLR', 'LambdaLR',]

def train(args, accelerator, model,
          loss_func,
          metric_func,
          train_loader,
          train_subset_loader,
          valid_loader,
          optimizer,
          lr_scheduler,
          epochs=10,
          writer=None,
          device="cuda",
          patience=10,
          grad_clip=0.999,
          start_epoch: int = 0,
          print_freq: int = 20,
          model_save_path='./models',
          save_mode='state_dict',
          model_name='model.pt',
          result_name='result.pt'):
    loss_train = []
    loss_val = []
    loss_epoch = []
    lr_history = []
    it = 0

    if patience is None or patience == 0:
        patience = epochs
    result = None
    start_epoch = start_epoch
    end_epoch = start_epoch + epochs
    best_val_metric = np.inf
    best_val_epoch = None
    save_mode = 'state_dict' if save_mode is None else save_mode
    stop_counter = 0
    is_epoch_scheduler = any(s in str(lr_scheduler.__class__) for s in EPOCH_SCHEDULERS)

    for epoch in range(start_epoch, end_epoch):
        model.train()
        torch.cuda.empty_cache()
        for _, batch in enumerate(train_loader):

            loss = train_batch(accelerator, model, loss_func, batch, optimizer, lr_scheduler, device,
                               grad_clip=grad_clip)

            loss = np.array(loss)
            loss_epoch.append(loss)
            it += 1
            lr = optimizer.param_groups[0]['lr']
            lr_history.append(lr)
            log = f"epoch: [{epoch + 1}/{end_epoch}]"
            if loss.ndim == 0:  # 1 target loss
                _loss_mean = np.mean(loss_epoch)
                log += " loss: {:.6f}".format(_loss_mean)
            else:
                _loss_mean = np.mean(loss_epoch, axis=0)
                for j in range(len(_loss_mean)):
                    log += " | loss {}: {:.6f}".format(j, _loss_mean[j])
            log += " | current lr: {:.3e}".format(lr)

            if it % print_freq == 0:
                print(log)

            if writer is not None:
                if _loss_mean.ndim == 0:  # 1 target loss
                    writer.add_scalar("train_loss", _loss_mean, it)
                else:
                    for j in range(len(_loss_mean)):
                        writer.add_scalar("train_loss_{}".format(j), _loss_mean[j], it)  # loss 0 seems to be the sum of all losses

        loss_train.append(_loss_mean)
        loss_epoch = []

        train_metric = validate_epoch(accelerator, model, metric_func, train_subset_loader, device)
        val_result = validate_epoch(accelerator, model, metric_func, valid_loader, device)

        loss_val.append(val_result["metric"])
        val_metric = val_result["metric"].sum()

        if val_metric < best_val_metric:
            best_val_epoch = epoch
            best_val_metric = val_metric
        #     stop_counter = 0
        #     if save_mode == 'state_dict':
        #         torch.save(model.state_dict(), os.path.join(model_save_path, model_name))
        #     else:
        #         torch.save(model, os.path.join(model_save_path, model_name))
        #     best_model_state_dict = {k: v.to('cpu') for k, v in model.state_dict().items()}
        #     best_model_state_dict = OrderedDict(best_model_state_dict)
        #
        # else:
        #     stop_counter += 1

        if lr_scheduler and is_epoch_scheduler:
            if 'ReduceLROnPlateau' in str(lr_scheduler.__class__):
                lr_scheduler.step(val_metric)
            else:
                lr_scheduler.step()

        if val_result["metric"].size == 1:
            log = "| val metric 0: {:.6f} ".format(val_metric)

        else:
            log = ''
            for i, metric_i in enumerate(val_result['metric']):
                log += '| val metric {} : {:.6f} '.format(i, metric_i)
            # metric_0, metric_1 = val_result["metric"][0], val_result["metric"][1]
            # log = "| val metric 1: {:.6f} | val metric 2: {:.6f} ".format(metric_0, metric_1)

        if writer is not None:
            if val_result["metric"].size == 1:
                writer.add_scalar('val loss {}'.format(metric_func.component), val_metric, epoch)
                writer.add_scalar('train metric {}'.format(metric_func.component), train_metric['metric'], epoch)
            else:
                for i, metric_i in enumerate(val_result['metric']):
                    writer.add_scalar('val loss {}'.format(i), metric_i, epoch)
                    writer.add_scalar('train metric {}'.format(i), train_metric['metric'][i], epoch)
        log += "| best val: {:.6f} at epoch {} | current lr: {:.3e}".format(best_val_metric, best_val_epoch + 1, lr)

        desc_ep = ""
        if _loss_mean.ndim == 0:  # 1 target loss
            desc_ep += "| loss: {:.6f}".format(_loss_mean)
        else:
            for j in range(len(_loss_mean)):
                if _loss_mean[j] > 0:
                    desc_ep += "| loss {}: {:.3e}".format(j, _loss_mean[j])

        desc_ep += log
        print(desc_ep)

        result = dict(
            best_val_epoch=best_val_epoch,
            best_val_metric=best_val_metric,
            loss_train=np.asarray(loss_train),
            loss_val=np.asarray(loss_val),
            lr_history=np.asarray(lr_history),
            # best_model=best_model_state_dict,
            optimizer_state=optimizer.state_dict()
        )
        # pickle.dump(result, open(os.path.join(model_save_path, result_name),'wb'))
    return result

# @timing
def train_batch(accelerator, model, loss_func, data, optimizer, lr_scheduler, device, grad_clip=0.999):
    optimizer.zero_grad()

    g, u_p, g_u = data

    g, g_u, u_p = g.to(device), g_u.to(device), u_p.to(device)

    out = model(g, u_p, g_u)

    y_pred, y = out.squeeze(), g.ndata['y'].squeeze()
    loss, reg, _ = loss_func(g, y_pred, y)
    loss = loss + reg

    accelerator.backward(loss)
    # loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    if lr_scheduler:
        lr_scheduler.step()

    return (loss.item(), reg.item())

# @timing
def validate_epoch(accelerator, model, metric_func, valid_loader, device):
    model.eval()
    metric_val = []
    for i, data in enumerate(valid_loader):
        with torch.no_grad():
            g, u_p, g_u = data
            g, g_u, u_p = g.to(device), g_u.to(device), u_p.to(device)

            out = model(g, u_p, g_u)

            y_pred, y = out.squeeze(), g.ndata['y'].squeeze()
            _, _, metric = metric_func(g, y_pred, y)
            metric_b = accelerator.gather_for_metrics(metric)  # gather per-sample metrics from all processes

            metric_val.append(metric_b.cpu().numpy())

    result = dict(metric=np.concatenate(metric_val).mean(axis=0))

    return result

if __name__ == "__main__":
    args = get_args()

    accelerator = Accelerator(split_batches=False)

    # each process gets its device from the Accelerator
    device = accelerator.device
    print(device)

    kwargs = {'pin_memory': False} if args.gpu else {}
    get_seed(args.seed, printout=False)

    train_dataset, test_dataset = get_dataset(args)
    # test_dataset = get_dataset(args)

    # train_sampler = SubsetRandomSampler(torch.arange(len(train_dataset)))
    # test_sampler = SubsetRandomSampler(torch.arange(len(test_dataset)))
    # train_loader = GraphDataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size, collate_fn=collate_op, drop_last=False)
    # test_loader = GraphDataLoader(test_dataset, sampler=test_sampler, batch_size=args.val_batch_size, collate_fn=collate_op,drop_last=False)
    # random 10% subset of the training set, used to monitor the training metric
    train_subset = Subset(train_dataset, torch.randperm(len(train_dataset))[:int(len(train_dataset) * 0.1)])
    train_loader = MIODataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_op, shuffle=True, drop_last=False)
    train_subset_loader = MIODataLoader(train_subset, batch_size=args.batch_size, collate_fn=collate_op, shuffle=False, drop_last=False)
    valid_loader = MIODataLoader(test_dataset, batch_size=args.batch_size, collate_fn=collate_op, shuffle=True, drop_last=False)

    args.space_dim = int(re.search(r'\d', args.dataset).group())
    args.normalizer = train_dataset.y_normalizer.to(device) if train_dataset.y_normalizer is not None else None

    #### set random seeds
    get_seed(args.seed)
    torch.cuda.empty_cache()

    loss_func = get_loss_func(name=args.loss_name, args=args, regularizer=True, normalizer=args.normalizer)
    metric_func = get_loss_func(name='rel2', args=args, regularizer=False, normalizer=args.normalizer)

    model = get_model(args)
    model = model.to(device)
    print(f"\nModel: {model.__name__}\t Number of params: {get_num_params(model)}")

    path_prefix = args.dataset + '_{}_'.format(args.component) + model.__name__ + args.comment + time.strftime(
        '_%m%d_%H_%M_%S')
    model_path, result_path = path_prefix + '.pt', path_prefix + '.pkl'

    print(f"Saving model and result in ./data/checkpoints/{model_path}\n")

    if args.use_tb:
        writer_path = './data/logs/' + path_prefix
        log_path = writer_path + '/params.txt'
        writer = SummaryWriter(log_dir=writer_path)
        fp = open(log_path, "w+")
        sys.stdout = fp

    else:
        writer = None
        log_path = None

    print(model)
    # print(config)

    epochs = args.epochs
    lr = args.lr

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay,
                                     betas=(args.beta1, args.beta2))
    elif args.optimizer == "AdamW":
        # if hasattr(model, 'configure_optimizers'):
        # print('Using model specified configured optimizer')
        # optimizer = model.configure_optimizers(lr=lr, weight_decay=args.weight_decay,betas=(0.9,0.999))
        # else:
        # optimizer = AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay,
                                      betas=(args.beta1, args.beta2))
    elif args.optimizer == 'CAdam':
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay, betas=(args.beta1, args.beta2))
    elif args.optimizer == "CAdamW":
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay, betas=(args.beta1, args.beta2))
    else:
        raise NotImplementedError

    if args.lr_method == 'cycle':
        print('Using cycle learning rate schedule')
        scheduler = OneCycleLR(optimizer, max_lr=lr, div_factor=1e4, pct_start=0.2, final_div_factor=1e4,
                               steps_per_epoch=len(train_loader), epochs=epochs)
    elif args.lr_method == 'step':
        print('Using step learning rate schedule')
        scheduler = StepLR(optimizer, step_size=args.lr_step_size * len(train_loader), gamma=0.5)
    elif args.lr_method == 'warmup':
        print('Using warmup learning rate schedule')
        scheduler = LambdaLR(optimizer, lambda steps: min((steps + 1) / (args.warmup_epochs * len(train_loader)),
                                                          np.power(args.warmup_epochs * len(train_loader) / float(steps + 1), 0.5)))
    elif args.lr_method == 'restart':
        print('Using cos anneal restart')
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=len(train_loader) * args.lr_step_size, eta_min=0.)
    elif args.lr_method == 'cyclic':
        scheduler = CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=args.lr_step_size * len(train_loader),
                             mode='triangular2', cycle_momentum=False)
    else:
        raise NotImplementedError

    time_start = time.time()

    model, optimizer, scheduler, train_loader, train_subset_loader, valid_loader = accelerator.prepare(model, optimizer, scheduler, train_loader, train_subset_loader, valid_loader)

    result = train(args, accelerator, model, loss_func, metric_func,
                   train_loader, train_subset_loader, valid_loader,
                   optimizer, scheduler,
                   epochs=epochs,
                   grad_clip=args.grad_clip,
                   patience=None,
                   model_name=model_path,
                   model_save_path='./data/checkpoints/',
                   result_name=result_path,
                   writer=writer,
                   device=device)

    print('Training takes {} seconds.'.format(time.time() - time_start))

    # result['args'], result['config'] = args, config
    # save the checkpoint from the main process only, unwrapping the model prepared by accelerate
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        checkpoint = {'args': args, 'model': accelerator.unwrap_model(model).state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(checkpoint, os.path.join('./data/checkpoints/{}'.format(model_path)))
    # pickle.dump(checkpoint, open(os.path.join('./../models/checkpoints/{}'.format(model_path), result_path),'wb'))
    # model.load_state_dict(torch.load(os.path.join('./../models/checkpoints/', model_path)))
    model.eval()
    val_metric = validate_epoch(accelerator, model, metric_func, valid_loader, device)
    print(f"\nBest model's validation metric in this run: {val_metric}")

Later I will clean the code and push to the main branch.

cfos3120 commented 7 months ago

It appears the NaNs also occur on a single GPU, but they are more likely with multiple GPUs. It seems to be related to the initialisation and ordering: changing only the seed can make a given setup either converge or explode to NaN.

Can you elaborate on the zero data channel? I am training on a lid-driven cavity flow, so the input coordinates contain zeros (unit-normalised) and the output has some zero boundary conditions (also normalised).
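
One way to narrow this down is to turn on PyTorch's anomaly detection and guard against non-finite losses (a generic debugging sketch, not part of the GNOT code):

import torch

# generic NaN-debugging sketch, not part of the GNOT code
torch.autograd.set_detect_anomaly(True)   # reports which backward op produced NaN/Inf

# inside the training step, a non-finite loss can be skipped instead of stepped on:
# if not torch.isfinite(loss):
#     print('non-finite loss, skipping batch')
#     optimizer.zero_grad()
# else:
#     loss.backward()
#     optimizer.step()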

HaoZhongkai commented 2 days ago

Hi, I'm not sure whether your problem has been solved. A metric like the relative L2 error is computed over the whole domain, so by itself it shouldn't cause NaN. However, if you have a function that is zero over the entire domain, you could consider using an MSE loss or padding those values to -1 rather than 0.
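
For illustration, a relative L2 metric divides by the norm of the target, so a target channel that is identically zero drives the denominator to zero, whereas MSE stays finite (a toy sketch, not the repo's loss implementation):

import torch

y_true = torch.zeros(100, 1)                 # a target channel that is identically zero
y_pred = 1e-3 * torch.randn(100, 1)

rel_l2 = torch.norm(y_pred - y_true) / torch.norm(y_true)    # denominator is 0 -> inf
mse = torch.mean((y_pred - y_true) ** 2)                      # stays finite

# the workarounds suggested above: use an MSE-style loss for such channels, or shift/pad
# the zero values to -1 before normalisation so the target norm is non-zero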