ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Tune] OOM error #33423

Closed. dangalea closed this issue 1 year ago.

dangalea commented 1 year ago

What happened + What you expected to happen

I am trying to run an HPO job on 30 nodes with 2 GPUs each, i.e. 60 GPUs in total; each node also has 72 CPUs. Unfortunately, I am running into out-of-memory (OOM) errors from Ray.

I am logging my results to wandb, and there I can see that the Process Memory Available (non-swap) and System Memory Utilization metrics go to zero, so this is a host memory problem, not a GPU memory issue. When I run any of the failed trials standalone, I do not get any issues, so I suspect that Ray is handling data (via my dataloaders) differently than I expect, or that it adds some overhead I am not accounting for. I have tried varying the batch size, but that doesn't solve the problem. I have also tried running with and without the max_concurrent argument, but the issue persists. That said, Ray's object_store_memory (as shown by ray status) is non-zero (but under 150 GB) when the argument is set, and zero otherwise. Would you be able to help?
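To see which phase of a trial (training, evaluation, metric computation) drives host memory toward zero, one option is to log the node's available memory alongside each report. The sketch below assumes a Linux node and parses the `MemAvailable` field of `/proc/meminfo`; the helper names `parse_meminfo` and `available_gib` are hypothetical and not part of Ray or W&B.

```python
import os
import re


def parse_meminfo(text: str) -> dict:
    """Parse /proc/meminfo-style text into {field: value in KiB}."""
    fields = {}
    for line in text.splitlines():
        # Lines look like "MemAvailable:   12345678 kB"; names such as
        # "Active(anon)" do not match the pattern and are skipped.
        match = re.match(r"^(\w+):\s+(\d+)", line)
        if match:
            fields[match.group(1)] = int(match.group(2))
    return fields


def available_gib(text: str) -> float:
    """Return MemAvailable in GiB (the kernel reports values in KiB)."""
    return parse_meminfo(text)["MemAvailable"] / (1024 ** 2)


# Assumption: Linux, where /proc/meminfo exists; elsewhere this is skipped.
if os.path.exists("/proc/meminfo"):
    with open("/proc/meminfo") as f:
        print(f"available: {available_gib(f.read()):.1f} GiB")
```

Calling `available_gib` once per epoch (and once before and after the metric computation) and adding the value to the reported metrics would show whether memory drops during evaluation or grows steadily across epochs.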

Versions / Dependencies

_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
aiosignal 1.3.1 pypi_0 pypi
attrs 22.2.0 pypi_0 pypi
blas 1.0 mkl
bottleneck 1.3.5 py310ha9d4c09_0 anaconda
brotli 1.0.9 h5eee18b_7
brotli-bin 1.0.9 h5eee18b_7
brotlipy 0.7.0 py310h7f8727e_1002
bzip2 1.0.8 h7b6447c_0
c-ares 1.18.1 h7f8727e_0
ca-certificates 2022.12.7 ha878542_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
cartopy 0.18.0 py310h95ad73f_2
cdsapi 0.5.1 pypi_0 pypi
certifi 2022.12.7 pyhd8ed1ab_0 conda-forge
cf-plot 3.1.28 pyhd8ed1ab_0 conda-forge
cf-python 3.13.1 py310h5764c6d_0 conda-forge
cfdm 1.9.0.4 py310hff52083_1 conda-forge
cffi 1.15.1 py310h74dc2b5_0
cftime 1.6.2 py310hde88566_1 conda-forge
cfunits 3.3.5 pyhd8ed1ab_0 conda-forge
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.3 pypi_0 pypi
cloudpickle 2.2.0 pypi_0 pypi
cryptography 38.0.1 py310h9ce1e76_0
cuda 11.7.1 0 nvidia
cuda-cccl 11.7.91 0 nvidia
cuda-command-line-tools 11.7.1 0 nvidia
cuda-compiler 11.7.1 0 nvidia
cuda-cudart 11.7.99 0 nvidia
cuda-cudart-dev 11.7.99 0 nvidia
cuda-cuobjdump 11.7.91 0 nvidia
cuda-cupti 11.7.101 0 nvidia
cuda-cuxxfilt 11.7.91 0 nvidia
cuda-demo-suite 11.8.86 0 nvidia
cuda-documentation 11.8.86 0 nvidia
cuda-driver-dev 11.7.99 0 nvidia
cuda-gdb 11.8.86 0 nvidia
cuda-libraries 11.7.1 0 nvidia
cuda-libraries-dev 11.7.1 0 nvidia
cuda-memcheck 11.8.86 0 nvidia
cuda-nsight 11.8.86 0 nvidia
cuda-nsight-compute 11.8.0 0 nvidia
cuda-nvcc 11.7.99 0 nvidia
cuda-nvdisasm 11.8.86 0 nvidia
cuda-nvml-dev 11.7.91 0 nvidia
cuda-nvprof 11.8.87 0 nvidia
cuda-nvprune 11.7.91 0 nvidia
cuda-nvrtc 11.7.99 0 nvidia
cuda-nvrtc-dev 11.7.99 0 nvidia
cuda-nvtx 11.7.91 0 nvidia
cuda-nvvp 11.8.87 0 nvidia
cuda-runtime 11.7.1 0 nvidia
cuda-sanitizer-api 11.8.86 0 nvidia
cuda-toolkit 11.7.1 0 nvidia
cuda-tools 11.7.1 0 nvidia
cuda-visual-tools 11.7.1 0 nvidia
curl 7.85.0 h5eee18b_0
cycler 0.11.0 pyhd3eb1b0_0
dbus 1.13.18 hb2f20db_0
distlib 0.3.6 pypi_0 pypi
esmf 8.4.0 mpi_mpich_h5a1934d_101 conda-forge
esmpy 8.4.0 mpi_mpich_py310h515c5ea_101 conda-forge
expat 2.4.9 h6a678d5_0
ffmpeg 4.3 hf484d3e_0 pytorch
fftw 3.3.9 h27cfd23_1
filelock 3.9.0 pypi_0 pypi
fontconfig 2.13.1 h6c09931_0
fonttools 4.25.0 pyhd3eb1b0_0
freetype 2.12.1 h4a9f257_0
frozenlist 1.3.3 pypi_0 pypi
gds-tools 1.4.0.31 0 nvidia
geos 3.8.0 he6710b0_0
giflib 5.2.1 h7b6447c_0
glib 2.69.1 h4ff587b_1
gmp 6.2.1 h295c915_3
gnutls 3.6.15 he1e5248_0
grpcio 1.51.3 pypi_0 pypi
gst-plugins-base 1.14.0 h8213a91_2
gstreamer 1.14.0 h28cd5cc_2
h5py 3.7.0 nompi_py310h416281c_102 conda-forge
hdf4 4.2.15 h9772cbc_5 conda-forge
hdf5 1.12.2 mpi_mpich_h08b82f9_0 conda-forge
icu 58.2 he6710b0_3
idna 3.4 py310h06a4308_0
intel-openmp 2021.4.0 h06a4308_3561
joblib 1.1.0 pyhd3eb1b0_0 anaconda
jpeg 9e h7f8727e_0
jsonschema 4.17.3 pypi_0 pypi
kiwisolver 1.4.2 py310h295c915_0
krb5 1.19.2 hac12032_0
lame 3.100 h7b6447c_0
lcms2 2.12 h3be6417_0
ld_impl_linux-64 2.38 h1181459_1
lerc 3.0 h295c915_0
libbrotlicommon 1.0.9 h5eee18b_7
libbrotlidec 1.0.9 h5eee18b_7
libbrotlienc 1.0.9 h5eee18b_7
libclang 10.0.1 default_hb85057a_2
libcublas 11.11.3.6 0 nvidia
libcublas-dev 11.11.3.6 0 nvidia
libcufft 10.9.0.58 0 nvidia
libcufft-dev 10.9.0.58 0 nvidia
libcufile 1.4.0.31 0 nvidia
libcufile-dev 1.4.0.31 0 nvidia
libcurand 10.3.0.86 0 nvidia
libcurand-dev 10.3.0.86 0 nvidia
libcurl 7.85.0 h91b91d3_0
libcusolver 11.4.1.48 0 nvidia
libcusolver-dev 11.4.1.48 0 nvidia
libcusparse 11.7.5.86 0 nvidia
libcusparse-dev 11.7.5.86 0 nvidia
libdeflate 1.8 h7f8727e_5
libedit 3.1.20210910 h7f8727e_0
libev 4.33 h7f8727e_1
libevent 2.1.12 h8f2d780_0
libffi 3.3 he6710b0_2
libgcc-ng 12.2.0 h65d4601_19 conda-forge
libgfortran-ng 11.2.0 h00389a5_1
libgfortran5 11.2.0 h1234567_1
libiconv 1.16 h7f8727e_2
libidn2 2.3.2 h7f8727e_0
libllvm10 10.0.1 hbcb73fb_5
libnetcdf 4.8.1 mpi_mpich_h06c54e2_4 conda-forge
libnghttp2 1.46.0 hce63b2e_0
libnpp 11.8.0.86 0 nvidia
libnpp-dev 11.8.0.86 0 nvidia
libnvjpeg 11.9.0.86 0 nvidia
libnvjpeg-dev 11.9.0.86 0 nvidia
libpng 1.6.37 hbc83047_0
libpq 12.9 h16c4e8d_3
libssh2 1.10.0 h8f2d780_0
libstdcxx-ng 12.2.0 h46fd767_19 conda-forge
libtasn1 4.16.0 h27cfd23_0
libtiff 4.4.0 hecacb30_0
libunistring 0.9.10 h27cfd23_0
libuuid 1.0.3 h7f8727e_2
libwebp 1.2.4 h11a3e52_0
libwebp-base 1.2.4 h5eee18b_0
libxcb 1.15 h7f8727e_0
libxkbcommon 1.0.1 hfa300c1_0
libxml2 2.9.14 h74e7548_0
libxslt 1.1.35 h4e12654_0
libzip 1.9.2 hc869a4a_1 conda-forge
libzlib 1.2.13 h166bdaf_4 conda-forge
llvm-openmp 14.0.6 h9e868ea_0
lz4-c 1.9.3 h295c915_1
matplotlib 3.5.2 py310h06a4308_0
matplotlib-base 3.5.2 py310hf590b9c_0
mkl 2021.4.0 h06a4308_640
mkl-service 2.4.0 py310h7f8727e_0
mkl_fft 1.3.1 py310hd6ae3a3_0
mkl_random 1.2.2 py310h00e6091_0
mpi 1.0 mpich conda-forge
mpi4py 3.1.4 py310h37cc914_0 conda-forge
mpich 4.0.3 h846660c_100 conda-forge
msgpack 1.0.4 pypi_0 pypi
munkres 1.1.4 py_0
ncurses 6.3 h5eee18b_3
netcdf-flattener 1.2.0 pyh9f0ad1d_0 conda-forge
netcdf-fortran 4.6.0 mpi_mpich_hd09bd1e_1 conda-forge
netcdf4 1.6.2 nompi_py310h55e1e36_100 conda-forge
nettle 3.7.3 hbbd107a_1
nsight-compute 2022.3.0.22 0 nvidia
nspr 4.33 h295c915_0
nss 3.74 h0370c37_0
numexpr 2.8.3 py310hcea2de6_0 anaconda
numpy 1.23.3 py310hd5efca6_0
numpy-base 1.23.3 py310h8e6c178_0
opencv-python-headless 4.6.0.66 pypi_0 pypi
openh264 2.1.1 h4ff587b_0
openssl 1.1.1s h0b41bf4_1 conda-forge
packaging 21.3 pyhd3eb1b0_0
pandas 1.4.3 py310h6a678d5_0 anaconda
parallelio 2.5.9 mpi_mpich_h50e6f33_101 conda-forge
pcre 8.45 h295c915_0
pillow 9.2.0 py310hace64e9_1
pip 22.2.2 py310h06a4308_0
platformdirs 3.0.0 pypi_0 pypi
ply 3.11 py310h06a4308_0
proj 7.2.0 h277dcde_2 conda-forge
protobuf 3.20.1 pypi_0 pypi
psutil 5.9.4 py310h5764c6d_0 conda-forge
pycparser 2.21 pyhd3eb1b0_0
pyopenssl 22.0.0 pyhd3eb1b0_0
pyparsing 3.0.9 py310h06a4308_0
pyqt 5.15.7 py310h6a678d5_1
pyqt5-sip 12.11.0 pypi_0 pypi
pyrsistent 0.19.3 pypi_0 pypi
pyshp 2.3.1 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 py310h06a4308_0
python 3.10.0 h12debd9_5
python-dateutil 2.8.2 pyhd3eb1b0_0
python_abi 3.10 2_cp310 conda-forge
pytorch 1.13.0 py3.10_cuda11.7_cudnn8.5.0_0 pytorch
pytorch-cuda 11.7 h67b0de4_0 pytorch
pytorch-model-summary 0.1.1 py_0 conda-forge
pytorch-mutex 1.0 cuda pytorch
pytz 2022.1 py310h06a4308_0 anaconda
pyyaml 6.0 pypi_0 pypi
qt-main 5.15.2 h327a75a_7
qt-webengine 5.15.9 hd2b0992_4
qtwebkit 5.212 h4eab89a_4
ray 2.3.0 pypi_0 pypi
readline 8.2 h5eee18b_0
requests 2.28.1 py310h06a4308_0
scikit-learn 1.1.1 py310h6a678d5_0 anaconda
scipy 1.9.1 py310hd5efca6_0
setuptools 65.5.0 py310h06a4308_0
shapely 1.8.4 py310h81ba7c5_0
sip 6.6.2 py310h6a678d5_0
six 1.16.0 pyhd3eb1b0_1
sqlite 3.39.3 h5082296_0
tabulate 0.9.0 pypi_0 pypi
tempest-extremes 2.2.1 mpi_mpich_h9b66f1e_0 conda-forge
tensorboardx 2.5.1 pypi_0 pypi
threadpoolctl 2.2.0 pyh0d69192_0 anaconda
tk 8.6.12 h1ccaba5_0
toml 0.10.2 pyhd3eb1b0_0
torch-metrics 1.1.7 pypi_0 pypi
torch-summary 1.4.5 pypi_0 pypi
torchaudio 0.13.0 py310_cu117 pytorch
torchmetrics 0.11.0 pypi_0 pypi
torchvision 0.14.0 py310_cu117 pytorch
tornado 6.2 py310h5eee18b_0
tqdm 4.64.1 py310h06a4308_0
typing_extensions 4.3.0 py310h06a4308_0
tzdata 2022e h04d1e81_0
udunits2 2.2.28 hc3e0081_0 conda-forge
urllib3 1.26.12 py310h06a4308_0
virtualenv 20.19.0 pypi_0 pypi
wheel 0.37.1 pyhd3eb1b0_0
xz 5.2.6 h5eee18b_0
yacs 0.1.8 pypi_0 pypi
yaml 0.2.5 h7b6447c_0 anaconda
zlib 1.2.13 h166bdaf_4 conda-forge
zstd 1.5.2 ha4553b6_0

Reproduction script

import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
from ray.air import session
from ray.air.integrations.wandb import setup_wandb
from ray.tune.search.basic_variant import BasicVariantGenerator
from ray.tune.search.hyperopt import HyperOptSearch
import torch, random, time, os
import numpy as np
from ray.util.multiprocessing import Pool as ray_pool
from utils import TrackingDataset
from hp_model import UNet

def get_dataloader(input_data_path, type, num_files, batch_size, step):

    if num_files is not None:
        dataset = TrackingDataset(input_data_path, type=type, start=0, end=num_files, step=step)
    else:
        dataset = TrackingDataset(input_data_path, type, step=step)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True, drop_last=True, num_workers=32)
    return loader
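A rough way to reason about the host memory held by each `DataLoader` is to count the batches it keeps in flight. Assuming PyTorch's default `prefetch_factor=2` (each worker loads that many batches ahead), a loader with `num_workers=32` and `batch_size=8` can buffer up to 32 × 2 × 8 = 512 samples, on top of whatever per-process state each worker holds (for example its copy of the `TrackingDataset` object). The helper below is hypothetical and only does that arithmetic.

```python
def max_buffered_samples(num_workers: int, batch_size: int, prefetch_factor: int = 2) -> int:
    """Upper bound on samples held in prefetch buffers by one DataLoader.

    Assumes PyTorch semantics: each of `num_workers` workers keeps up to
    `prefetch_factor` batches in flight (default 2 when num_workers > 0).
    """
    if num_workers == 0:
        # Single-process loading: batches are built on demand, no prefetch queue.
        return batch_size
    return num_workers * prefetch_factor * batch_size


# Settings used by get_dataloader: num_workers=32, batch_size=8.
print(max_buffered_samples(32, 8))  # 512
print(max_buffered_samples(4, 8))   # 64
```

Lowering `num_workers` (or passing a smaller `prefetch_factor`) shrinks this buffer proportionally, at the cost of slower input loading.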

def weighted_mse_loss(input, target, weights):
    # Scale channel j of every sample by weights[j]; weights are reshaped to
    # (1, C, 1, ..., 1) so they broadcast over the batch and spatial dimensions
    weights = torch.tensor(weights, dtype=input.dtype, device=input.device)
    weights = weights.view(1, -1, *([1] * (input.dim() - 2)))
    return torch.mean(weights * (input - target) ** 2)

def weighted_mae_loss(input, target, weights):
    # Same per-channel weighting as weighted_mse_loss, applied to absolute errors
    weights = torch.tensor(weights, dtype=input.dtype, device=input.device)
    weights = weights.view(1, -1, *([1] * (input.dim() - 2)))
    return torch.mean(weights * torch.abs(input - target))
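Both weighted losses scale channel `j` of every sample by `weights[j]`. Written as nested Python loops this performs one small tensor operation per (sample, channel) pair; written with broadcasting it is a single multiply. The NumPy sketch below, assuming outputs shaped `(N, C, H, W)` with channels on axis 1 (matching how the script indexes `mse[i][j]`), checks that the two forms agree. The function names are hypothetical.

```python
import numpy as np


def weighted_mse_loop(pred, target, weights):
    # Mirrors the nested loops: scale channel j of sample i by weights[j].
    sq = (pred - target) ** 2
    for i in range(len(sq)):
        for j in range(len(weights)):
            sq[i][j] = weights[j] * sq[i][j]
    return sq.mean()


def weighted_mse_broadcast(pred, target, weights):
    # Reshape weights to (1, C, 1, ..., 1) so they broadcast over batch and spatial axes.
    w = np.asarray(weights, dtype=pred.dtype).reshape(1, -1, *([1] * (pred.ndim - 2)))
    return (w * (pred - target) ** 2).mean()


rng = np.random.default_rng(0)
pred = rng.random((4, 2, 8, 8), dtype=np.float32)
target = rng.random((4, 2, 8, 8), dtype=np.float32)
weights = [0.045, 0.955]
print(np.isclose(weighted_mse_loop(pred, target, weights),
                 weighted_mse_broadcast(pred, target, weights)))  # True
```

The same reshape applies to a `torch.Tensor` through `.view`, which keeps the computation on the GPU as one kernel instead of many indexed writes.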

def train_one_epoch(model, training_loader, optimizer, loss_fn, scaler, loss_config, weights):

    epoch_loss = 0.

    batch_iter = enumerate(training_loader)

    for batch_i, data in batch_iter:

        optimizer.zero_grad()

        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()

        try:
            with torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model(inputs).squeeze()

                if loss_config != "CE":

                    new_labels = torch.zeros_like(outputs)
                    new_labels[:, 0][labels==0] = 1
                    new_labels[:, 1][labels==1] = 1
                    labels = new_labels.cuda()

                    if loss_config == "W-MSE":
                        loss = weighted_mse_loss(outputs.float(), labels.float(), weights)
                    elif loss_config == "W-MAE":
                        loss = weighted_mae_loss(outputs.float(), labels.float(), weights)
                    elif loss_config in ["BCE", "MSE", "MAE"]:
                        loss = loss_fn(outputs.float(), labels.float())
                else:
                    loss = loss_fn(outputs, labels.long())
        except Exception:
            print("Size:", outputs.size(), labels.size(), loss_config)
            return 1e6

        epoch_loss += loss.item() * len(inputs) / len(training_loader)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return epoch_loss
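In `train_one_epoch`, each batch adds `loss.item() * len(inputs) / len(training_loader)`. Since `len(training_loader)` is the number of batches, and `drop_last=True` makes every batch hold `batch_size` samples, the sum equals the mean batch loss multiplied by `batch_size`, not a per-sample mean. `train_model` ignores this return value, but `eval_model` uses the same pattern for `running_vloss`, which is reported as `loss`, so the reported validation loss is scaled by 8. If a per-sample mean is wanted, a minimal sketch with a hypothetical helper:

```python
def sample_weighted_mean(batch_losses, batch_sizes):
    """Mean loss per sample, given each batch's mean loss and its size."""
    total_loss = 0.0
    total_samples = 0
    for loss, n in zip(batch_losses, batch_sizes):
        total_loss += loss * n  # undo the per-batch mean
        total_samples += n
    return total_loss / total_samples


losses, sizes = [0.5, 0.3, 0.1], [8, 8, 8]
# Formula used in the script: sum(loss * n / num_batches)
script_style = sum(l * n / len(losses) for l, n in zip(losses, sizes))
print(sample_weighted_mean(losses, sizes))  # ≈ 0.3
print(script_style)                          # ≈ 2.4
```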

def eval_model(model, val_loader, loss_fn, loss_config, weights):

    running_vloss = 0.

    with torch.no_grad():

        val_preds = []
        val_labels = []

        batch_iter = enumerate(val_loader)

        for i, data in batch_iter:

            inputs, labels = data
            inputs = inputs.cuda()
            labels = labels.cuda()

            outputs = model(inputs.float())
            if loss_config != "CE":

                new_labels = torch.zeros_like(outputs)
                new_labels[:, 0][labels==0] = 1
                new_labels[:, 1][labels==1] = 1
                labels = new_labels.cuda()

                if loss_config == "W-MSE":
                    vloss = weighted_mse_loss(outputs.float(), labels.float(), weights).item()
                elif loss_config == "W-MAE":
                    vloss = weighted_mae_loss(outputs.float(), labels.float(), weights).item()
                elif loss_config in ["BCE", "MSE", "MAE"]:
                    vloss = loss_fn(outputs.float(), labels.float()).item()
            else:
                vloss = loss_fn(outputs, labels.long()).item()
            running_vloss += vloss * len(inputs) / len(val_loader)

            for output in outputs.cpu().detach().numpy():
                val_preds.append(output)
            for label in labels.cpu().detach().numpy():
                val_labels.append(label)

    miou_background, miou_ar = mIoU(val_preds, val_labels)

    return miou_background, miou_ar, running_vloss
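This version of `eval_model` keeps every validation output and label as NumPy arrays until the loop finishes, then passes all of them to `mIoU`, which ships them to Ray pool workers (likely serializing another copy into the object store). Under `torch.no_grad()` without autocast the outputs are float32, and for the non-CE losses the labels have been one-hot encoded to the same shape, so each sample costs roughly `2 × C × H × W × 4` bytes. The issue does not state the spatial size of `TrackingDataset` samples, so the `height`/`width` values below are placeholders, not the real dataset dimensions, and the helper is hypothetical.

```python
def accumulated_eval_bytes(num_samples, channels, height, width, itemsize=4, one_hot_labels=True):
    """Rough bytes held by val_preds + val_labels after a full validation pass.

    Assumes outputs of shape (channels, height, width) per sample with the given
    itemsize (4 for float32), and, for the non-CE losses, labels one-hot encoded
    to the same shape; otherwise labels are (height, width) with the same itemsize.
    """
    per_output = channels * height * width * itemsize
    per_label = per_output if one_hot_labels else height * width * itemsize
    return num_samples * (per_output + per_label)


# Placeholder sizes for illustration only; NOT the dataset's real dimensions.
gib = accumulated_eval_bytes(num_samples=10_000, channels=2, height=512, width=512) / 2**30
print(f"{gib:.1f} GiB per trial")  # 39.1 GiB per trial
```

With two trials per node, any such per-trial total is paid twice on the same host.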

def mIoU(preds, labels):

    inputs = []
    for i in range(len(preds)):
        inputs.append([preds[i], labels[i]])

    pool = ray_pool(min(len(inputs), 10))
    results = pool.starmap(iou_case, inputs, chunksize=1)
    pool.close()

    iou_background = []
    iou_ar = []
    for result in results:
        iou_bg_case, iou_ar_case = result
        iou_background.append(iou_bg_case)
        iou_ar.append(iou_ar_case)

    return np.mean(iou_background), np.mean(iou_ar)

def iou_case(pred, label):

    pred = np.argmax(pred, axis=0)

    pred = pred.flatten()
    label = label.flatten()

    pred_label = np.arange(pred.shape[0])[pred==0]
    target_label = np.arange(label.shape[0])[label==0]
    num_intersection_label = np.intersect1d(pred_label, target_label).shape[0]
    num_union_label = np.union1d(pred_label, target_label).shape[0]
    iou_bg = num_intersection_label/num_union_label

    pred_label = np.arange(pred.shape[0])[pred==1]
    target_label = np.arange(label.shape[0])[label==1]
    num_intersection_label = np.intersect1d(pred_label, target_label).shape[0]
    num_union_label = np.union1d(pred_label, target_label).shape[0]
    if num_union_label == 0:
        iou_ar = 0
    else:
        iou_ar = num_intersection_label/num_union_label

    return iou_bg, iou_ar
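`iou_case` counts, for one sample, the pixels where prediction and label both equal a class (via `np.intersect1d` on index arrays) and the pixels where either does (via `np.union1d`). The same counts can be taken with boolean masks for a whole batch at once, which removes the need for a process pool. The sketch below is a hypothetical helper assuming predictions of shape `(N, C, H, W)` and integer class labels of shape `(N, H, W)`. For the non-CE losses, `labels` has already been one-hot encoded to `(N, C, H, W)` by the time it reaches `mIoU`, so `labels.argmax(axis=1)` would be needed first to recover class indices. Unlike `iou_case`, it also returns 0 for the background class when its union is empty instead of raising `ZeroDivisionError`.

```python
import numpy as np


def batch_iou(preds, labels, cls):
    """Per-sample IoU for class `cls`.

    preds:  (N, C, H, W) scores; the predicted class is the argmax over axis 1.
    labels: (N, H, W) integer class indices.
    Returns an (N,) array; samples whose union is empty get IoU 0.
    """
    pred_cls = preds.argmax(axis=1) == cls
    true_cls = labels == cls
    inter = np.logical_and(pred_cls, true_cls).sum(axis=(1, 2))
    union = np.logical_or(pred_cls, true_cls).sum(axis=(1, 2))
    # np.divide with `where` skips empty unions; those entries keep the 0 from `out`.
    return np.divide(inter, union, out=np.zeros(len(inter)), where=union > 0)


preds = np.zeros((2, 2, 2, 2))
preds[:, 0] = 1.0          # background score 1 everywhere
preds[0, 1, 0, 0] = 2.0    # sample 0 predicts AR at pixel (0, 0)
labels = np.zeros((2, 2, 2), dtype=int)
labels[0, 0, :] = 1        # sample 0: AR at pixels (0, 0) and (0, 1)

print(batch_iou(preds, labels, 1))  # AR IoU per sample
print(batch_iou(preds, labels, 0))  # background IoU per sample
```

Inside the evaluation loop, extending two lists with the per-batch results for classes 0 and 1 keeps only N floats per batch in memory.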

def train_model(config, epochs=15, num_files=None, input_data_path = "/p/lustre1/galea1/full_ar_tracking_dataset/"):

    torch.manual_seed(1)
    random.seed(1)
    np.random.seed(1)    

    batch_size = 8

    config = f_unpack_dict(config)

    num_kernels = config["num_kernels"]
    downsampling_factor = config["downsampling_factor"]
    kernel_multiplier = config["kernel_multiplier"]
    kernel_size = config["kernel_size"]
    num_layers = config["num_layers"]
    batch_norm = config["batch_norm"]
    activation = config["activation_fn"]
    dropout = config["dropout"]
    weight_init = config["weight_init"]
    opt = config["optimizer"]
    loss_config = config["loss"]

    # Read the W&B API key from the environment rather than hard-coding it in the script
    wandb = setup_wandb(api_key=os.environ.get("WANDB_API_KEY"), project=config["project"], group=config["group"], name=opt+"_{:.5f}".format(config["lr"])+"_{:.5f}".format(config["dropout"]))

    model = UNet(num_kernels=num_kernels, in_classes=7, downsampling_factor=downsampling_factor, stride=1, kernel_multiplier=kernel_multiplier, kernel_size=kernel_size, num_layers=num_layers, batch_norm=batch_norm, activation=activation, dropout=dropout, weight_init=weight_init).cuda()

    zeros = 36167075181
    ones = 1703081619
    total = ones+zeros
    weights = [1-zeros/total, 1-ones/total]

    if opt == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    elif opt == "SGD-M":
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=config["mom"])
    elif opt == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], betas=(config["beta1"], config["beta2"]), eps=config["eps"])
    elif opt == "NAdam":
        optimizer = torch.optim.NAdam(model.parameters(), lr=config["lr"], betas=(config["beta1"], config["beta2"]), eps=config["eps"])
    elif opt == "RMSProp":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=config["lr"], alpha=config["alpha"], eps=config["eps"])
    elif opt == "AdamW":
        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], betas=(config["beta1"], config["beta2"]), eps=config["eps"])

    if loss_config == "CE":
        loss_fn = torch.nn.CrossEntropyLoss(weight=torch.Tensor(weights)).cuda()
    elif loss_config == "BCE":
        loss_fn = torch.nn.BCEWithLogitsLoss().cuda()
    elif loss_config == "MSE":
        loss_fn = torch.nn.MSELoss().cuda()
    elif loss_config == "MAE":
        loss_fn = torch.nn.L1Loss().cuda()
    else:
        loss_fn = None

    train_loader = get_dataloader(input_data_path, "train", num_files, batch_size, step=10)
    val_loader = get_dataloader(input_data_path, "val", num_files, batch_size, step=1)

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(epochs):

        model.train(True)

        train_one_epoch(model, train_loader, optimizer, loss_fn, scaler, loss_config, weights)

        miou_bg, miou_ar, loss = eval_model(model, val_loader, loss_fn, loss_config, weights)

        # Hyperparameters reported for every optimizer choice
        common = {"num_kernels": num_kernels, "downsampling_factor": downsampling_factor,
                  "kernel_multiplier": kernel_multiplier, "kernel_size": kernel_size,
                  "num_layers": num_layers, "batch_norm": batch_norm, "activation": activation,
                  "dropout": dropout, "weight_init": weight_init, "optimizer": opt,
                  "loss_fn": loss_config}
        # Optimizer-specific hyperparameters: reported name -> key in config
        extra_keys = {
            "SGD": {},
            "SGD-M": {"momentum": "mom"},
            "Adam": {"beta1": "beta1", "beta2": "beta2", "eps": "eps"},
            "NAdam": {"beta1": "beta1", "beta2": "beta2", "eps": "eps"},
            "RMSProp": {"alpha": "alpha", "eps": "eps"},
            "AdamW": {"beta1": "beta1", "beta2": "beta2", "eps": "eps"},
        }[opt]
        extras = {name: config[key] for name, key in extra_keys.items()}

        # Tune gets loss, miou_ar and comb_miou; W&B additionally gets miou_bg and lr
        session.report({"loss": loss, "miou_ar": miou_ar, **common, **extras, "comb_miou": miou_ar + miou_bg})
        wandb.log({"loss": loss, "miou_ar": miou_ar, "miou_bg": miou_bg, "lr": config["lr"], **common, **extras, "comb_miou": miou_ar + miou_bg})

        time.sleep(30)
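The class weights in `train_model` are complement-of-frequency weights: since `1 - zeros/total` equals `ones/total`, the background class (label 0) gets the small weight and the rarer AR class (label 1) gets the large one, and the two sum to 1. A quick check with the pixel counts hard-coded in the script:

```python
zeros = 36167075181  # background pixel count from the script
ones = 1703081619    # AR pixel count from the script
total = ones + zeros
weights = [1 - zeros / total, 1 - ones / total]

print([round(w, 4) for w in weights])  # [0.045, 0.955]
```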

def f_unpack_dict(dct):

    res = {}
    for (k, v) in dct.items():
        if isinstance(v, dict):
            res = {**res, **f_unpack_dict(v)}
        else:
            res[k] = v

    return res
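`f_unpack_dict` flattens the nested search-space entries (the sampled `optimizer` dict and the `wandb` dict) into one level, which is why `train_model` can read `config["lr"]` and `config["project"]` directly. The outer `"optimizer"` key is never assigned its dict value; it ends up holding the inner name string (for example `"Adam"`). A small check on a hand-written sample config, with values made up for illustration:

```python
def f_unpack_dict(dct):
    # Same logic as in the script: recursively merge nested dicts into one level.
    res = {}
    for k, v in dct.items():
        if isinstance(v, dict):
            res = {**res, **f_unpack_dict(v)}
        else:
            res[k] = v
    return res


sample = {
    "num_kernels": 4,
    "optimizer": {"optimizer": "Adam", "lr": 1e-3, "beta1": 0.9, "beta2": 0.99, "eps": 1e-8},
    "wandb": {"project": "AR_Tracking", "group": "init_run_small_batch4"},
}
flat = f_unpack_dict(sample)
print(flat["optimizer"], flat["lr"], flat["project"])  # Adam 0.001 AR_Tracking
```

Because later keys overwrite earlier ones, a nested key that shares a name with another entry would silently replace it, so the nested dicts should keep distinct key names.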

if __name__ == "__main__":

    ray.init(address="auto")

    gpus_per_trial = 1
    num_samples = 50

    input_data_path = "/p/lustre1/galea1/full_ar_tracking_dataset/"
    num_files = None

    config = {
        "num_kernels": tune.choice([2, 4, 6, 8, 10, 12, 14]),
        "downsampling_factor": tune.choice([2, 3, 4]),
        "kernel_multiplier": tune.choice([2, 3, 4]),
        "kernel_size": tune.choice([1, 2, 3, 4]),
        "num_layers": tune.choice([1, 2, 3, 4]),
        "batch_norm": tune.choice([0, 1]),
        "activation_fn": tune.choice(["ReLU", "Leaky ReLU", "Tanh", "Softmax"]),
        "dropout": tune.uniform(0, 0.5),
        "weight_init": tune.choice(["He", "Xavier", "Uniform", "Zeros"]),
        "optimizer": tune.choice([
            {"optimizer": "SGD", "lr": tune.loguniform(1e-5, 1e-1)},
            {"optimizer": "SGD-M", "lr": tune.loguniform(1e-5, 1e-1), "mom": tune.loguniform(0.1, 0.99)},
            {"optimizer": "Adam", "lr": tune.loguniform(1e-5, 1e-1), "beta1": tune.loguniform(0.5, 0.999), "beta2": tune.loguniform(0.5, 0.999), "eps": tune.loguniform(1e-8, 1e-3)},
            {"optimizer": "NAdam", "lr": tune.loguniform(1e-5, 1e-1), "beta1": tune.loguniform(0.5, 0.999), "beta2": tune.loguniform(0.5, 0.999), "eps": tune.loguniform(1e-8, 1e-3)},
            {"optimizer": "AdamW", "lr": tune.loguniform(1e-5, 1e-1), "beta1": tune.loguniform(0.5, 0.999), "beta2": tune.loguniform(0.5, 0.999), "eps": tune.loguniform(1e-8, 1e-3)},
            {"optimizer": "RMSProp", "lr": tune.loguniform(1e-5, 1e-1), "alpha": tune.loguniform(0.5, 0.999), "eps": tune.loguniform(1e-8, 1e-3)},
        ]),
        "loss": tune.choice(["W-MSE", "W-MAE", "MAE", "MSE", "BCE", "CE"]),
        "wandb": {
            "project": "AR_Tracking",
            "group": "init_run_small_batch4"
        }
    }

    scheduler = FIFOScheduler()

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_model),
            resources={"cpu": 32, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="comb_miou",
            mode="max",
            scheduler=scheduler,
            num_samples=num_samples,
            search_alg=BasicVariantGenerator(max_concurrent=40)
        ),
        param_space=config
    )
    results = tuner.fit()

    best_result = results.get_best_result("comb_miou", "max")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(best_result.metrics["loss"]))
    print("Best trial final validation mIoU (AR class): {}".format(best_result.metrics["miou_ar"]))
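With 30 nodes of 72 CPUs and 2 GPUs, and each trial reserving 32 CPUs and 1 GPU, a node fits two trials (both the CPU and GPU limits give 2), leaving 8 CPUs per node unreserved, and `max_concurrent=40` caps the run below the 60 trial slots the cluster could hold. The arithmetic as a small sketch; the helper name is hypothetical, and it ignores any other resource demands, such as the actors started by `ray_pool`:

```python
def trials_per_node(node_cpus, node_gpus, trial_cpus, trial_gpus):
    # A trial fits only if both its CPU and GPU requests fit on the node.
    return min(node_cpus // trial_cpus, node_gpus // trial_gpus)


nodes, node_cpus, node_gpus = 30, 72, 2
per_node = trials_per_node(node_cpus, node_gpus, trial_cpus=32, trial_gpus=1)
slots = nodes * per_node
concurrent = min(slots, 40)  # BasicVariantGenerator(max_concurrent=40)
spare_cpus = node_cpus - per_node * 32

print(per_node, slots, concurrent, spare_cpus)  # 2 60 40 8
```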

Issue Severity

High: It blocks me from completing my task.

dangalea commented 1 year ago

Bumping this @matthewdeng @justinvyu

dangalea commented 1 year ago

I changed my eval_model() method so that it no longer processes the validation data as a whole. It now looks like this:

def eval_model(model, val_loader, loss_fn, loss_config, weights):

    running_vloss = 0.

    all_miou_bg = []
    all_miou_ar = []

    with torch.no_grad():

        batch_iter = enumerate(val_loader)

        pool = ray_pool(10)

        for i, data in batch_iter:

            val_preds = []
            val_labels = []

            inputs, labels = data
            inputs = inputs.cuda()
            labels = labels.cuda()

            outputs = model(inputs.float())
            if loss_config != "CE":

                new_labels = torch.zeros_like(outputs)
                new_labels[:, 0][labels==0] = 1
                new_labels[:, 1][labels==1] = 1
                labels = new_labels.cuda()

                if loss_config == "W-MSE":
                    vloss = weighted_mse_loss(outputs.float(), labels.float(), weights).item()
                elif loss_config == "W-MAE":
                    vloss = weighted_mae_loss(outputs.float(), labels.float(), weights).item()
                elif loss_config in ["BCE", "MSE", "MAE"]:
                    vloss = loss_fn(outputs.float(), labels.float()).item()
            else:
                vloss = loss_fn(outputs, labels.long()).item()
            running_vloss += vloss * len(inputs) / len(val_loader)

            for output in outputs.cpu().detach().numpy():
                val_preds.append(output)
            for label in labels.cpu().detach().numpy():
                val_labels.append(label)

            miou_background, miou_ar = mIoU(pool, val_preds, val_labels)

            # Keep only the per-sample IoU floats, not the batch arrays
            all_miou_bg.extend(miou_background)
            all_miou_ar.extend(miou_ar)

    pool.close()

    return np.mean(all_miou_bg), np.mean(all_miou_ar), running_vloss

def mIoU(pool, preds, labels):

    inputs = []
    for i in range(len(preds)):
        inputs.append([preds[i], labels[i]])

    results = pool.starmap(iou_case, inputs, chunksize=1)

    iou_background = []
    iou_ar = []
    for result in results:
        iou_bg_case, iou_ar_case = result
        iou_background.append(iou_bg_case)
        iou_ar.append(iou_ar_case)

    return iou_background, iou_ar

This seems to have solved the problem: the predictions for the whole validation set are no longer held in memory at once, but are processed batch by batch.
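The revised code creates the pool once per evaluation, reuses it for every batch, and closes it after the loop. A pattern that also releases the workers if an exception escapes the loop is to manage the pool with a `with` block. The sketch below uses the standard library's `multiprocessing.pool.ThreadPool` as a stand-in so it runs anywhere; whether `ray.util.multiprocessing.Pool` behaves identically as a context manager should be checked against the Ray version in use. `iou_pair` is a hypothetical placeholder for `iou_case`.

```python
from multiprocessing.pool import ThreadPool


def iou_pair(pred, label):
    # Hypothetical stand-in for iou_case: returns two per-sample scores.
    return float(pred == label), float(pred != label)


def evaluate(batches, processes=4):
    all_bg, all_ar = [], []
    # One pool for the whole evaluation; the with-block terminates it even on error.
    with ThreadPool(processes) as pool:
        for preds, labels in batches:
            results = pool.starmap(iou_pair, zip(preds, labels), chunksize=1)
            # Only the per-sample floats are kept, never the batch arrays.
            all_bg.extend(r[0] for r in results)
            all_ar.extend(r[1] for r in results)
    return sum(all_bg) / len(all_bg), sum(all_ar) / len(all_ar)


batches = [([0, 1], [0, 0]), ([1, 1], [1, 1])]
print(evaluate(batches))  # (0.75, 0.25)
```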