Based on the traceback, the CUDA OOM error occurs during the prediction stage when sampling from the model, specifically within the `ar_sample` method. The `ar_sample` method calls `run_nps_model_ar`, which runs auto-regressive (AR) predictions that consume a lot of memory. This is a plausible failure mode: in AR sampling, each sampled target point is fed back in as context for the next step, so memory grows with both the number of samples and the number of target points (a schematic sketch below illustrates this).
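As a rough illustration (this is a schematic only, not DeepSensor's actual `run_nps_model_ar`, and `sample_one` is a made-up name), the memory cost of AR sampling scales with both the number of samples and the number of AR steps, because the context set grows at every step:

```python
# Schematic sketch of autoregressive sampling (NOT DeepSensor's real code):
# each sampled target is appended to the context, so later steps condition
# on an ever-larger context set.
def ar_sample_sketch(model, context, target_points, n_samples):
    all_samples = []
    for _ in range(n_samples):                # memory scales with n_samples
        ctx = list(context)
        sample = []
        for x_t in target_points:             # ... and with the number of AR steps
            y_t = model.sample_one(ctx, x_t)  # hypothetical single-point draw
            sample.append(y_t)
            ctx.append((x_t, y_t))            # context grows each iteration
        all_samples.append(sample)
    return all_samples
```

Reducing `n_samples` shrinks the outer loop, and coarsening the AR subsampling shrinks the inner one, which motivates the first two recommendations below: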
- **Reduce the number of samples:** if memory is the problem, reducing `n_samples` significantly decreases memory usage, since each AR sample is drawn and stored independently.
- **Coarsen the AR subsampling:** increasing `ar_subsample_factor` reduces the number of target points sampled autoregressively, which lowers both the computation and the memory footprint per sample.
- **Memory profiling:** profile memory before and after the important calls to identify where the spikes occur (a probe helper is sketched just after this list).
- **Clear the CUDA cache:** call `torch.cuda.empty_cache()` regularly to release unused cached memory.
- **Mixed precision training:** implement mixed precision to save memory during operations.
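For the profiling point, a small helper along these lines can localize the spikes. This is a sketch (`cuda_mem_probe` is a hypothetical name), but it relies only on standard `torch.cuda` memory APIs:

```python
import torch
from contextlib import contextmanager

@contextmanager
def cuda_mem_probe(label):
    """Report peak GPU memory allocated inside the wrapped block."""
    torch.cuda.reset_peak_memory_stats()
    yield
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{label}] peak allocated: {peak_gb:.2f} GB")
    torch.cuda.empty_cache()  # release cached blocks once the block finishes

# Usage:
# with cuda_mem_probe("ar_sample prediction"):
#     pred = model.predict(test_task, X_t=anomalies, n_samples=1, ar_sample=True)
```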
Here's an implementation with these steps:
```python
import logging
import warnings

import torch
from torch.cuda.amp import GradScaler, autocast
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import pandas as pd
import numpy as np
from tqdm import tqdm
import plum  # needed for the warning filter below

import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP
from deepsensor.train import set_gpu_default_device


def print_memory_usage():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.4f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.4f} GB")


warnings.filterwarnings("ignore", category=plum.resolver.MethodRedefinitionWarning)

dat14 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2014.nc'
dat15 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2015.nc'
dat16 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2016.nc'
dat = xr.open_mfdataset([dat14, dat15, dat16],
                        concat_dim='time',
                        combine='nested',
                        chunks={'lat': 'auto', 'lon': 'auto'})

# Replace NaNs (land) with a sentinel value, then compute daily anomalies
mdat = dat.where(~np.isnan(dat.sst), -0.009)
climatology = mdat.groupby('time.dayofyear').mean('time')
anomalies = mdat.groupby('time.dayofyear') - climatology

data_processor = DataProcessor(x1_name="lat", x2_name="lon")
anom_ds = data_processor(anomalies)

task_loader = TaskLoader(
    context=anom_ds,
    target=anom_ds,
)

train_tasks = []
for date in pd.date_range('2015-01-02', '2015-12-31')[::5]:
    task = task_loader(date, context_sampling="all", target_sampling="all")
    train_tasks.append(task)

val_tasks = []
for date in pd.date_range('2016-01-01', '2016-12-31'):
    task = task_loader(date, context_sampling="all", target_sampling="all")
    val_tasks.append(task)

set_gpu_default_device()

# Mixed precision training
scaler = GradScaler()

model = ConvNP(data_processor, task_loader)

# Training loop: a manual loop (rather than deepsensor's Trainer) so that
# autocast/GradScaler can be used; the optimizer must be created explicitly
optimizer = torch.optim.Adam(model.model.parameters(), lr=5e-5)

losses = []
val_rmses = []
val_rmse_best = np.inf
for epoch in range(10):
    print_memory_usage()  # log memory at the start of each epoch
    batch_losses = []
    for task in train_tasks:
        optimizer.zero_grad()
        with autocast():
            loss = model.loss_fn(task, normalise=True)  # deepsensor's training loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_losses.append(loss.item())
    print_memory_usage()  # log memory after the epoch
    losses.append(np.mean(batch_losses))
    val_rmses.append(compute_val_rmse(model, val_tasks))  # defined in the updated code below
    if val_rmses[-1] < val_rmse_best:
        val_rmse_best = val_rmses[-1]
    torch.cuda.empty_cache()  # clear cache regularly to prevent OOM

# Prediction
print_memory_usage()  # check GPU memory before prediction
test_task = task_loader("2016-07-19T12:00:00", context_sampling=["all"], seed_override=42)
pred = model.predict(test_task, X_t=anomalies, n_samples=1,
                     ar_sample=True, ar_subsample_factor=20)  # n_samples=1 and coarse AR subsampling to limit memory
print_memory_usage()  # check GPU memory after prediction
```
Key changes:

- **Reduced `n_samples`:**

  ```python
  pred = model.predict(test_task, X_t=anomalies, n_samples=1, ar_sample=True, ar_subsample_factor=20)
  ```
- **Memory usage logging:** memory usage is logged before and after the crucial steps to understand where it peaks.
- **Clearing cache:** `torch.cuda.empty_cache()` is called frequently to free unused cached memory.
- **Mixed precision training:** uses `autocast` and `GradScaler` to reduce memory usage during training; the sketch below shows how the same idea might be applied at prediction time.
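Since the OOM actually fires at prediction time, it may also be worth wrapping the prediction itself in `torch.no_grad()` and, if the model tolerates it numerically, `torch.autocast`. This is a hedged sketch that has not been verified against deepsensor's `predict` internals:

```python
# Sketch: cut prediction-time memory by disabling autograd bookkeeping and
# running the forward pass in half precision. Whether autocast is safe here
# depends on deepsensor/neuralprocesses internals, so check output accuracy.
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model.predict(test_task, X_t=anomalies, n_samples=1,
                             ar_sample=True, ar_subsample_factor=20)
```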
If the issue persists despite these mitigations, it could be helpful to report the problem to the maintainers of the relevant libraries (`deepsensor`, `plum`, etc.).
Found a relevant Stack Overflow post: https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
Updated sample code
```python
import logging
import warnings

import torch
from torch.cuda.amp import GradScaler, autocast
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import pandas as pd
import numpy as np
from tqdm import tqdm  # terminal-based tqdm
import plum  # needed for the warning filter below

import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP
from deepsensor.train import set_gpu_default_device


def print_memory_usage():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.4f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.4f} GB")


def compute_val_rmse(model, val_tasks):
    """Unnormalised RMSE over a random subset of validation tasks."""
    if not val_tasks:
        print("Validation tasks are empty!")
        return np.nan
    errors = []
    target_var_ID = task_loader.target_var_IDs[0][0]  # assume 1st target set and 1D
    for task in np.random.choice(val_tasks, min(50, len(val_tasks)), replace=False):
        mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
        true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
        errors.extend(np.abs(mean - true))
    if errors:
        return np.sqrt(np.mean(np.concatenate(errors) ** 2))
    return np.nan


warnings.filterwarnings("ignore", category=plum.resolver.MethodRedefinitionWarning)

dat14 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2014.nc'
dat15 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2015.nc'
dat16 = '/nfs/turbo/seas-dannes/SST-sensor-placement-input/GLSEA3_NETCDF/GLSEA3_2016.nc'
dat = xr.open_mfdataset([dat14, dat15, dat16],
                        concat_dim='time',
                        combine='nested',
                        chunks={'lat': 'auto', 'lon': 'auto'})

mdat = dat.where(~np.isnan(dat.sst), -0.009)
climatology = mdat.groupby('time.dayofyear').mean('time')
anomalies = mdat.groupby('time.dayofyear') - climatology

data_processor = DataProcessor(x1_name="lat", x2_name="lon")
anom_ds = data_processor(anomalies)

task_loader = TaskLoader(
    context=anom_ds,
    target=anom_ds,
)

# Debugging: print the dates actually present in the dataset
available_dates = pd.to_datetime(dat['time'].values)
available_dates_set = set(available_dates)
print("Available dates in dataset:", available_dates)


# Filter pd.date_range to only include available dates
def filter_available_dates(dates):
    return dates[dates.isin(available_dates_set)]


filtered_train_dates = filter_available_dates(pd.date_range('2015-01-02', '2015-12-31')[::5])
filtered_val_dates = filter_available_dates(pd.date_range('2016-01-01', '2016-12-31'))
print("Filtered train dates:", filtered_train_dates)
print("Filtered val dates:", filtered_val_dates)


def gen_tasks(dates, progress=True):
    tasks = []
    for date in tqdm(dates, disable=not progress):
        task = task_loader(date, context_sampling=["all"], target_sampling="all")
        tasks.append(task)
    return tasks


train_tasks = gen_tasks(filtered_train_dates, progress=True)
val_tasks = gen_tasks(filtered_val_dates, progress=True)
print(f"Number of training tasks: {len(train_tasks)}")
print(f"Number of validation tasks: {len(val_tasks)}")

set_gpu_default_device()

# Mixed precision training
scaler = GradScaler()

model = ConvNP(data_processor, task_loader)

# Training loop: a manual loop (rather than deepsensor's Trainer) so that
# autocast/GradScaler can be used; the optimizer is created explicitly
optimizer = torch.optim.Adam(model.model.parameters(), lr=5e-5)

losses = []
val_rmses = []
val_rmse_best = np.inf
for epoch in range(10):
    train_tasks = gen_tasks(filtered_train_dates, progress=False)
    print_memory_usage()  # log memory at the start of each epoch
    batch_losses = []
    for task in train_tasks:
        optimizer.zero_grad()
        with autocast():
            loss = model.loss_fn(task, normalise=True)  # deepsensor's training loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_losses.append(loss.item())
    print_memory_usage()  # log memory after the epoch
    losses.append(np.mean(batch_losses))
    rmse = compute_val_rmse(model, val_tasks)
    val_rmses.append(rmse)
    if not np.isnan(rmse) and rmse < val_rmse_best:
        val_rmse_best = rmse
    torch.cuda.empty_cache()  # clear cache regularly to prevent OOM

# Prediction
print_memory_usage()  # check GPU memory before prediction
test_task = task_loader("2016-07-19T12:00:00", context_sampling=["all"], seed_override=42)
pred = model.predict(test_task, X_t=anomalies, n_samples=1,
                     ar_sample=True, ar_subsample_factor=20)  # n_samples=1 and coarse AR subsampling to limit memory
print_memory_usage()  # check GPU memory after prediction
```
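If prediction still runs out of memory even with `n_samples=1`, one further option is to predict over spatial chunks so that only part of the target grid is on the GPU at once. This is a sketch under the assumption that `model.predict` accepts an arbitrary xarray target grid for `X_t`:

```python
# Sketch: split the target grid into latitude slabs and predict per slab,
# emptying the CUDA cache between slabs to bound peak memory.
preds = []
for lat_band in np.array_split(anomalies['lat'].values, 4):
    X_t_chunk = anomalies.sel(lat=slice(lat_band[0], lat_band[-1]))
    preds.append(model.predict(test_task, X_t=X_t_chunk, n_samples=1,
                               ar_sample=True, ar_subsample_factor=20))
    torch.cuda.empty_cache()
```

The trade-off is that AR samples are then only coherent within each slab, not across the whole domain.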
Closing due to putting the AR sampling on the back-burner
CUDA Out of Memory Error: Recommendations and Debugging Steps

**Issue Description**

We're encountering a CUDA out of memory error while running our model training and prediction on GPUs.

**Code**

- Imports and Setup
- Data Preparation
- Task Generation & Model Training
- Prediction

**Recommendations for Reducing CUDA Out of Memory Errors**

- Use `torch.cuda.memory_summary()` to understand memory usage.
- Use `torch.cuda.empty_cache()` to clear unused memory.
- Logging memory usage.
- Reducing context sampling size: modify the `gen_tasks` function to potentially reduce memory usage, as sketched below.
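A sketch of such a modification follows; it assumes DeepSensor's `TaskLoader` accepts an integer `context_sampling` to draw a fixed number of random context points (the `n_context` parameter is made up for illustration):

```python
# Sketch: sample a fixed number of random context points per task instead of
# "all" grid points; smaller tasks mean a smaller encoder input and lower
# peak GPU memory per forward pass.
def gen_tasks(dates, n_context=1000, progress=True):
    tasks = []
    for date in tqdm(dates, disable=not progress):
        task = task_loader(date,
                           context_sampling=n_context,  # was "all"
                           target_sampling="all")
        tasks.append(task)
    return tasks
```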
**Concepts Highlighted in Code**

- Model complexity
- Mixed precision training
- Clearing the cache regularly
- Testing prediction with reduced samples