samuelstevens closed this 5 months ago.
@samuelstevens how many dataloader workers did you specify? You don't want to go overboard with workers.
The callstack isn't very helpful, which is typical. But given that it looks like it may be getting killed externally, are you running out of system memory? While it's running, use htop / top / free etc. to see what the system memory use looks like. Also check dmesg and the syslogs to see whether the OOM killer was active.
8 workers per process with 8 GPUs, so 64 total. I agree that it is likely not actually an S3 error. It's running on MosaicML's cloud platform with 8xH100s. I will try with some memory logging tools and dig into dmesg and syslogs to see if I can establish that running out of RAM is the issue and will report back. Thanks for the tips.
@samuelstevens yeah, that's nothing crazy, so it should be okay. I usually find 4-6 workers enough with basic preprocessing and the larger models, but more can help keep utilization higher if the models are smaller.
There can also be shared-memory issues when running in a container (shm flags are needed), but I think that usually results in a more specific error. Slurm or other job managers might also set memory limits via cgroups that could kill greedy processes. I'm not sure what sort of setup that cloud platform is running.
I can confirm that the `evaluate()` function in `training/train.py` kills it. I think it's running out of system RAM. I have ~96K validation examples, which I don't think is too many, but I notice a FIXME comment in this function warning about memory usage.
This is the snippet (copied here to make it easier):
```python
# FIXME this does not scale past small eval datasets
# all_image_features @ all_text_features will blow up memory and compute very quickly
cumulative_loss = 0.0
cumulative_gen_loss = 0.0
all_image_features, all_text_features = [], []
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        images, texts = batch
        images = images.to(device=device, dtype=input_dtype, non_blocking=True)
        texts = texts.to(device=device, non_blocking=True)

        with autocast():
            model_out = model(images, texts)
            image_features = model_out["image_features"]
            text_features = model_out["text_features"]
            logit_scale = model_out["logit_scale"]
            # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
            # however, system RAM is easily exceeded and compute time becomes problematic
            all_image_features.append(image_features.cpu())
            all_text_features.append(text_features.cpu())
```
And then these `all_image_features` are used in `get_clip_metrics`, which is what causes the OOM:
```python
val_metrics = get_clip_metrics(
    image_features=torch.cat(all_image_features),
    text_features=torch.cat(all_text_features),
    logit_scale=logit_scale.cpu(),
)
```
```python
import numpy as np
import torch


def get_clip_metrics(image_features, text_features, logit_scale):
    metrics = {}
    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    logits_per_text = logits_per_image.t().detach().cpu()

    logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)

    return metrics
```
I guess it has to create two 96K x 96K matrices (`logits_per_image` and `logits_per_text`), which are about 18GB each at 2 bytes per element (roughly double that in float32), plus whatever intermediate memory it needs for these calculations.
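A back-of-the-envelope check of those sizes, as a sketch: it assumes N = 95,702 (the `--val-num-samples` value in the command later in this thread) and decimal gigabytes, and `gb` is a hypothetical helper. It also sizes the `torch.argsort` output, which holds int64 indices (8 bytes each) and is produced once per direction inside the loop. Note that `Tensor.t()` returns a view and `.cpu()` on a tensor already in CPU memory returns it without copying, so `logits_per_text` should share storage with `logits_per_image` rather than being a second full matrix.

```python
def gb(n_elements: int, bytes_per_element: int) -> float:
    """Size in decimal gigabytes (1e9 bytes) of a dense tensor."""
    return n_elements * bytes_per_element / 1e9


n = 95_702          # validation samples (--val-num-samples)
elements = n * n    # entries in one N x N similarity matrix

sizes = {
    "logits, float16": gb(elements, 2),
    "logits, float32": gb(elements, 4),
    "argsort indices, int64": gb(elements, 8),
    "ranking == ground_truth mask, bool": gb(elements, 1),
}
for label, size in sizes.items():
    print(f"{label:>36}: {size:6.1f} GB")
```

That comes to roughly 18 GB per matrix in float16, 37 GB in float32, and 73 GB for each int64 `argsort` result, before counting other temporaries.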
As I understand it, this function calculates img2txt and txt2img scores among all examples in the validation set. How important are these metrics vs just tracking validation loss per batch and averaging at the end of the epoch, like we would in supervised learning? Do you find that these img2txt and txt2img retrieval scores are significantly more informative?
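For comparison, a minimal sketch of the per-batch alternative: compute the symmetric CLIP contrastive loss on each batch and take a sample-weighted average at the end. It assumes already-normalized features and uses NumPy instead of torch only to stay self-contained; the function names are hypothetical, not open_clip's API. Memory per batch is B x B instead of N x N, but the loss depends on batch size (more negatives per batch raises it), so values are only comparable at a fixed batch size.

```python
import numpy as np


def _cross_entropy_diag(logits: np.ndarray) -> float:
    """Mean cross-entropy where the correct class for row i is column i."""
    # Numerically stable log-softmax along each row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def clip_batch_loss(image_features, text_features, logit_scale):
    """Symmetric image-to-text and text-to-image loss for one batch of pairs."""
    logits_per_image = logit_scale * image_features @ text_features.T  # (B, B)
    return 0.5 * (
        _cross_entropy_diag(logits_per_image) + _cross_entropy_diag(logits_per_image.T)
    )


def mean_val_loss(batches, logit_scale):
    """Sample-weighted average of per-batch losses over (image, text) feature pairs."""
    total, count = 0.0, 0
    for image_features, text_features in batches:
        b = image_features.shape[0]
        total += clip_batch_loss(image_features, text_features, logit_scale) * b
        count += b
    return total / count
```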
I'm posting my debugging here in the hope that it eventually helps someone else.
I used `dmesg -T` to track memory usage. I find that there are 136 instances of python running: I assume it's 8 for the 8 main training processes, then 8 x 8 = 64 for each of the dataloaders (val and train), for a total of 136 processes.
When I sum the `rss` column of the table produced by `dmesg` and multiply by my page size (`getconf PAGESIZE`), I find that the dataloader workers are using 859GB, and the system only has 945GB.
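The summing step can be scripted; here is a sketch that parses the OOM killer's task table from saved `dmesg -T` output. The row layout is an assumption based on recent kernels (`[pid] uid tgid total_vm rss pgtables_bytes swapents oom_score_adj name`, with `rss` in pages) and differs on older kernels, and `rss_by_name` is a hypothetical name. Also, RSS counts shared pages in every process that maps them, so summing it across forked workers can overstate real usage.

```python
import os
import re
from collections import defaultdict

# Assumed row layout of the OOM-killer task dump (recent kernels):
# [  pid  ]   uid  tgid total_vm      rss pgtables_bytes swapents oom_score_adj name
_ROW = re.compile(
    r"\[\s*(?P<pid>\d+)\]\s+(?P<uid>\d+)\s+(?P<tgid>\d+)\s+(?P<total_vm>\d+)\s+"
    r"(?P<rss>\d+)\s+(?P<pgtables>\d+)\s+(?P<swapents>\d+)\s+(?P<adj>-?\d+)\s+"
    r"(?P<name>\S+)\s*$"
)


def rss_by_name(dmesg_lines, page_size=None):
    """Sum the rss column (pages) per process name; return (bytes, row counts)."""
    if page_size is None:
        page_size = os.sysconf("SC_PAGE_SIZE")  # same value as `getconf PAGESIZE`
    totals, counts = defaultdict(int), defaultdict(int)
    for line in dmesg_lines:
        m = _ROW.search(line)
        if m:  # header lines and other kernel messages are skipped
            totals[m.group("name")] += int(m.group("rss")) * page_size
            counts[m.group("name")] += 1
    return dict(totals), dict(counts)
```

Usage would be along the lines of reading `dmesg -T` output saved to a file and calling `rss_by_name(open(path))`.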
This is enough for me to think that I have too many dataloader workers and that it's always the dreaded OOM killer causing these issues. Thanks @rwightman for the suggestion to use `dmesg`; that was really helpful, and it was a tool I hadn't used before.
While this has certainly mitigated the issue, I am still seeing memory usage increase in the wandb logs (see screenshot). I suspect there is a memory leak somewhere. @rwightman, are you aware of any common culprits for memory leaks? My complete command is:
```sh
composer \
  -m training.main \
  --train-data "pipe:aws s3 cp s3://bucket-name/directory/shard-{000000..002549}.tar - || true" \
  --train-num-samples 12000000 \
  --val-data "pipe:aws s3 cp s3://bucket-name/directory/shard-{002540..002559}.tar - || true" \
  --val-num-samples 95702 \
  --dataset-resampled \
  --dataset-type webdataset \
  --model ViT-B-16 \
  --pretrained openai \
  --batch-size 4096 \
  --epochs 40 \
  --log-every-n-steps 10 \
  --eps 1e-6 \
  --lr 5e-4 \
  --local-loss \
  --gather-with-grad \
  --grad-checkpointing \
  --precision amp \
  --workers 4 \
  --seed 42 \
  --warmup 1000 \
  --report-to wandb \
  --wandb-project-name clip \
  --logs /logs \
  --remote-sync s3://bucket-name/checkpoints
```
Could it be the `--remote-sync` option?
@samuelstevens hmm, memory leaks are always fun to pin down. I haven't used the remote-sync option extensively myself, so I'm not sure if there's something there. Are you sure it's not the usual memory churn and fragmentation at the beginning, i.e., it never reaches a stable level after some time (even if that level is quite high)?
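One way to answer that question from logged numbers, as a sketch: sample the resident set size periodically and compare the mean of the most recent window with the window before it. It assumes Linux (`/proc/self/status`) and only measures the calling process, not the dataloader workers, which are separate processes. Both function names are hypothetical, and the window size and 2% tolerance are arbitrary illustrative choices.

```python
from typing import Optional, Sequence


def current_rss_bytes() -> Optional[int]:
    """Resident set size of this process from /proc/self/status (Linux only)."""
    try:
        with open("/proc/self/status") as f:
            for line in f:
                if line.startswith("VmRSS:"):
                    return int(line.split()[1]) * 1024  # reported in kB
    except OSError:
        return None  # not Linux, or /proc unavailable
    return None


def still_growing(samples: Sequence[float], window: int, tol_fraction: float = 0.02) -> bool:
    """True if the mean of the last window exceeds the previous window's mean by > tol."""
    if window < 1 or len(samples) < 2 * window:
        raise ValueError("need at least 2 * window samples")
    recent_mean = sum(samples[-window:]) / window
    earlier_mean = sum(samples[-2 * window:-window]) / window
    return (recent_mean - earlier_mean) / earlier_mean > tol_fraction
```

Churn and fragmentation would show early growth followed by `still_growing` returning False, while a leak keeps it True epoch after epoch.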
Yeah, the total number of Python processes will be `num_gpu + num_gpu * num_dl_workers` for each dataloader holding workers (train and val here). There will also be processes launched for each `aws s3` pipe; those scale with the number of dataloader workers, but they're just streaming bytes, so they shouldn't allocate much.
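The count can be written as a one-line formula. This sketch assumes both the train and val loaders keep their workers alive, which matches the 136 processes reported earlier; the function name is hypothetical, and the `aws s3 cp` pipe processes are not included.

```python
def expected_python_processes(num_gpus: int, workers_per_loader: int, num_loaders: int = 1) -> int:
    """One main process per GPU plus every loader's workers on every GPU."""
    return num_gpus * (1 + workers_per_loader * num_loaders)


# 8 GPUs, 8 workers per loader, train and val loaders both holding workers
print(expected_python_processes(8, 8, num_loaders=2))  # 136
```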
I usually patch the memory allocator for this sort of training; some allocators behave better than the defaults for use cases like this, with less waste and better performance under load.
```sh
sudo apt install google-perftools
```
In your training environment, run
```sh
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
```
before launching the training process (the library path varies a bit by distro). This hooks `malloc` so tcmalloc is used instead of the default allocator.
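Because the library path varies by distro, a launcher script could look for it instead of hard-coding it. This is a sketch: the directory list is an assumption covering common layouts, and `find_tcmalloc` / `env_with_preload` are hypothetical helpers. Existing `LD_PRELOAD` entries are kept, separated by a colon.

```python
import glob
import os
from typing import Mapping, Optional, Sequence

# Common shared-library directories (assumption; the exact location varies by distro).
DEFAULT_DIRS = ("/usr/lib/x86_64-linux-gnu", "/usr/lib64", "/usr/lib", "/usr/local/lib")


def find_tcmalloc(search_dirs: Sequence[str] = DEFAULT_DIRS) -> Optional[str]:
    """Return the first libtcmalloc.so* found in search_dirs, or None."""
    for d in search_dirs:
        matches = sorted(glob.glob(os.path.join(d, "libtcmalloc.so*")))
        if matches:
            return matches[0]
    return None


def env_with_preload(lib_path: str, base_env: Optional[Mapping[str, str]] = None) -> dict:
    """Copy of the environment with lib_path prepended to LD_PRELOAD."""
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("LD_PRELOAD", "")
    env["LD_PRELOAD"] = f"{lib_path}:{existing}" if existing else lib_path
    return env
```

The returned dict can be passed as `env=` to `subprocess.run` when launching the training command.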
And yeah, that eval was intended for a full sample-to-sample comparison, so the set of samples needs to stay small to keep the matrix size sane. If average loss or average metrics over batches is desired, that's possible and not unreasonable, but it isn't set up that way right now.
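If the full sample-to-sample comparison is still wanted, one swapped-in technique (not what open_clip does) is to avoid both the N x N matrix and `argsort`: process queries in chunks and take the rank of the correct match as the number of candidates scoring strictly higher than it. This is a sketch in NumPy with a hypothetical function name; peak memory is `chunk_size x N` floats. It matches the `argsort` ranks when there are no ties; with ties, this version counts tied candidates in the query's favour.

```python
import numpy as np


def chunked_clip_metrics(image_features, text_features, logit_scale=1.0, chunk_size=1024):
    """Same keys as get_clip_metrics, computed without an N x N matrix."""
    n = image_features.shape[0]
    metrics = {}
    directions = (
        ("image_to_text", image_features, text_features),
        ("text_to_image", text_features, image_features),
    )
    for name, queries, keys in directions:
        preds = np.empty(n, dtype=np.int64)
        for start in range(0, n, chunk_size):
            stop = min(start + chunk_size, n)
            logits = logit_scale * queries[start:stop] @ keys.T  # (chunk, n)
            # Score of each query's true match sits at column start + row.
            correct = logits[np.arange(stop - start), np.arange(start, stop)]
            # Zero-based rank = how many candidates score strictly higher.
            preds[start:stop] = (logits > correct[:, None]).sum(axis=1)
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in (1, 5, 10):
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
    return metrics
```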
I typically used a zero-shot eval loss for tracking performance during training.
Oops, I forgot the screenshot. It very clearly looks like a memory leak that never stabilizes (this is 40 epochs). I will try without `--remote-sync` and see if that's the issue.
With respect to the eval code, I will probably just use validation loss. Thanks for the insight.
Closing this because we debugged the error to be OOM. If the memory leak is an issue I will open a new issue. Thanks!
`s3://bucket-name/directory/shard-{000000..002000}.tar`.
`--train-data "pipe:aws s3 cp s3://bucket-name/directory/shard-{000000..000256}.tar - || true"` for my download command. The logs look like this:

It seems like `aws s3 cp` fails, but I can successfully download shards on my cluster. Are there any suggestions for how to handle this?