Closed sachinraja13 closed 1 month ago
Could you share a snippet to reproduce the memory leak?
To reproduce:
import os
import json
import numpy as np
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mimm.models.detection.detr.utils import NestedTensor
from mimm.models.detection.detr import detr
def batch_images_and_masks(inputs):
    images = []
    masks = []
    original_image_sizes = []
    for i in range(len(inputs)):
        images.append(inputs[i]['image'])
        masks.append(inputs[i]['mask'])
        original_image_sizes.append(inputs[i]['orig_size'])
    images = mx.array(images)
    masks = mx.array(masks)
    original_image_sizes = mx.array(original_image_sizes)
    return images, masks, original_image_sizes
inputs = []
i = {}
i['image'] = mx.random.uniform(shape=(1024, 1024, 3))
i['mask'] = mx.zeros(shape=(1024, 1024))
i['orig_size'] = mx.array([1024, 1024])
inputs.append(i)
targets = []
t = {}
t['boxes'] = mx.random.uniform(shape=(50, 4))
t['labels'] = mx.ones((50), dtype=mx.int64)
targets.append(t)
images, masks, original_image_sizes = batch_images_and_masks(inputs)
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from scipy.optimize import linear_sum_assignment
class DETRMatcher:
    def __init__(self, cost_class: float = 1.0, cost_bbox: float = 5.0, cost_giou: float = 2.0):
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
def forward(self, outputs, targets):
bs, num_queries = outputs['pred_logits'].shape[:2]
out_prob = nn.softmax(outputs['pred_logits'].flatten(0, 1)) # [batch_size * num_queries, num_classes]
out_bbox = outputs['pred_boxes'].flatten(0, 1) # [batch_size * num_queries, 4]
tgt_ids = mx.concatenate([v['labels'] for v in targets])
tgt_bbox = mx.concatenate([v['boxes'] for v in targets])
cost_class = -out_prob[:, tgt_ids]
cost_bbox = self.compute_l1_distance(out_bbox, tgt_bbox)
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = np.asarray(C.reshape(bs, num_queries, -1))
sizes = [len(v['boxes']) for v in targets]
indices = []
for i, c in enumerate(C):
# print(c.shape)
if i == 0:
start_index = 0
end_index = start_index + sizes[i]
else:
start_index = sizes[i-1]
end_index = start_index + sizes[i]
cost_matrix = c[:, start_index:end_index]
indices.append(linear_sum_assignment(cost_matrix))
return [(mx.array(i, dtype=mx.int64), mx.array(j, dtype=mx.int64)) for i, j in indices]
@staticmethod
def compute_l1_distance(src_boxes, tgt_boxes):
src_boxes = src_boxes[:, None, :] # [batch_size * num_queries, 1, 4]
tgt_boxes = tgt_boxes[None, :, :] # [1, num_targets, 4]
return mx.sum(mx.abs(src_boxes - tgt_boxes), axis=-1) # [batch_size * num_queries, num_targets]
def generalized_box_iou(boxes1, boxes2):
    iou, union = box_iou(boxes1, boxes2)
    enclose_area = (mx.maximum(boxes1[:, None, 2:], boxes2[:, 2:]) - mx.minimum(boxes1[:, None, :2], boxes2[:, :2])).prod(2)
    giou = iou - (enclose_area - union) / (enclose_area + mx.array(1e-7))
    return giou
def box_iou(boxes1, boxes2):
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
lt = mx.maximum(boxes1[:, None, :2], boxes2[:, :2])
rb = mx.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
wh = mx.clip(rb - lt, 0, None)
inter = wh[:, :, 0] * wh[:, :, 1]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    b = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
    return mx.stack(b, axis=1)
def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return mx.stack(b, axis=1)
class DETRLoss:
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.empty_weight = mx.ones(self.num_classes + 1)
        self.empty_weight[0] = self.eos_coef
def loss_labels(self, outputs, targets, indices, num_boxes):
src_logits = outputs['pred_logits']
idx = self._get_src_permutation_idx(indices)
target_classes_o = mx.concatenate([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = mx.full(src_logits.shape[:2], 0, dtype=mx.int32)
target_classes[idx] = target_classes_o
loss_ce = mx.mean(nn.losses.cross_entropy(src_logits, target_classes, weights=self.empty_weight[target_classes]))
losses = {'loss_ce': loss_ce}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs['pred_boxes'][idx]
target_boxes = mx.concatenate([t['boxes'][i] for t, (_, i) in zip(targets, indices)], axis=0)
loss_bbox = nn.losses.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}
losses['loss_bbox'] = loss_bbox.sum() * 5.0 / num_boxes
loss_giou = 1 - mx.diag(generalized_box_iou(
box_cxcywh_to_xyxy(src_boxes),
box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() * 2.0 / num_boxes
return losses
def _get_src_permutation_idx(self, indices):
batch_idx = mx.concatenate([mx.full((src.shape[0],), i, dtype=mx.int64) for i, (src, _) in enumerate(indices)])
src_idx = mx.concatenate([src for (src, _) in indices])
return batch_idx, src_idx
def forward(self, outputs, targets):
indices = self.matcher.forward(outputs, targets)
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = mx.array([num_boxes], dtype=mx.float32)
num_boxes = mx.maximum(num_boxes, mx.array([1.0], dtype=mx.float32))
losses = {}
for loss in self.losses:
loss_fn = getattr(self, f'loss_{loss}')
losses.update(loss_fn(outputs, targets, indices, num_boxes))
losses = {k: v for k, v in losses.items()}
return losses
class DETRTrainer:
    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
def loss_fn(self, nested_tensor, targets):
outputs = self.model(nested_tensor)
loss_dict = self.criterion.forward(outputs, targets)
loss = sum(loss for loss in loss_dict.values())[0]
return loss
def train_one_epoch(self, epoch):
self.model.train()
total_loss = 0.0
iteration = 0
total_iterations = 500
for i in range(total_iterations):
images, masks, orig_sizes = batch_images_and_masks(inputs)
t = NestedTensor(images, masks)
loss_value, grads = self.loss_and_grad_fn(t, targets)
grads, total_norm = optim.clip_grad_norm(grads, max_norm=0.1)
self.optimizer.update(self.model, grads)
total_loss += loss_value.item()
if iteration % 1 == 0:
print(f'Epoch: {epoch}, Iteration: {iteration}, Loss: {loss_value.item()}')
iteration += 1
# break
avg_loss = total_loss / total_iterations
print(f'Epoch {epoch} completed. Average Loss: {avg_loss}')
def fit(self, num_epochs):
self.loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn)
for epoch in range(num_epochs):
self.train_one_epoch(epoch)
# Optionally, you can add validation here
# self.evaluate(self.data_loader)
if __name__ == "__main__":
    mx.set_default_device(mx.gpu)
    model, processor = detr.detr_resnet50(num_classes=1, num_queries=100, hidden_dim=256)
    matcher = DETRMatcher()
    criterion = DETRLoss(num_classes=1, matcher=matcher, weight_dict=None, eos_coef=0.1, losses=["labels", "boxes"])  # Adjust parameters as needed
    lr_schedule = optim.step_decay(0.0001, 0.5, 2000)
    optimizer = optim.AdamW(learning_rate=lr_schedule)
    trainer = DETRTrainer(model, criterion, optimizer)
    trainer.fit(num_epochs=10)
In Activity Monitor, you will see the memory used by the Python process climbing very quickly. This was not the case with MLX 0.15. I did not make any changes to the code in the MIMM package.
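For what it's worth, the growth can also be logged from inside the script rather than from Activity Monitor; a minimal sketch (assuming the mx.metal memory helpers) that prints active and peak Metal memory, which can be called once per training iteration:
import mlx.core as mx
def log_memory(step):
    # Report MLX's currently active and peak Metal memory in GB.
    active_gb = mx.metal.get_active_memory() / 1e9
    peak_gb = mx.metal.get_peak_memory() / 1e9
    print(f"step {step}: active={active_gb:.2f} GB, peak={peak_gb:.2f} GB")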
Cool, thanks for sharing that. I can indeed reproduce it. It looks like it was introduced between 0.15.2 and 0.16. I will bisect and look for the issue.
It started showing up after https://github.com/ml-explore/mlx/pull/1246.
Many thanks for your help with this.
So there does seem to be a leak of some sort. However, if you eval the model and optimizer state after each iteration (which is good practice), the leak seems to go away. So I would recommend that you do that:
self.optimizer.update(self.model, grads)
mx.eval(self.model, self.optimizer.state) # <-- add that line
total_loss += loss_value.item()
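For context, here is a minimal self-contained sketch of that pattern on a toy model (the nn.Linear / optim.SGD setup here is only illustrative, not from the repro above):
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
model = nn.Linear(256, 256)
optimizer = optim.SGD(learning_rate=1e-3)
def loss_fn(model, x):
    return model(x).sum()
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x = mx.random.uniform(shape=(32, 256))
for step in range(100):
    loss, grads = loss_and_grad_fn(model, x)
    optimizer.update(model, grads)
    # Evaluate the updated parameters and optimizer state every step so the
    # lazy computation graph does not keep accumulating across iterations.
    mx.eval(model.parameters(), optimizer.state)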
A simple repro which shows the leak:
import mlx.core as mx
if __name__ == "__main__":
params = [mx.zeros((256, 256))]
def loss_fn(params, x):
x = x @ params[0].T
outputs = mx.fast.layer_norm(x, None, None, 1e-4)
mx.eval(outputs)
return outputs.sum()
loss_and_grad_fn = mx.value_and_grad(loss_fn)
src = mx.random.uniform(shape=(1, 32, 32, 256))
for i in range(50000):
loss_value, grads = loss_and_grad_fn(params, src)
params[0] = grads[0] + params[0]
print(mx.metal.get_peak_memory())
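And, for illustration only, the same loop with the per-iteration eval suggested above added (whether this fully flattens the memory curve here is easy to check with the same get_peak_memory print):
import mlx.core as mx
if __name__ == "__main__":
    params = [mx.zeros((256, 256))]
    def loss_fn(params, x):
        x = x @ params[0].T
        outputs = mx.fast.layer_norm(x, None, None, 1e-4)
        mx.eval(outputs)
        return outputs.sum()
    loss_and_grad_fn = mx.value_and_grad(loss_fn)
    src = mx.random.uniform(shape=(1, 32, 32, 256))
    for i in range(50000):
        loss_value, grads = loss_and_grad_fn(params, src)
        params[0] = grads[0] + params[0]
        # Evaluate the updated parameters each iteration instead of letting
        # the graph (and its buffers) grow across iterations.
        mx.eval(params)
    print(mx.metal.get_peak_memory())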
Thank you, @awni! This solved the problem.
This is mostly closed by #1274
Hi Team,
I used the DETR implementation from the MIMM package here. With MLX 0.15, the memory usage during training remained nearly constant at 11.46 GB. However, after installing MLX 0.16, the memory consumed grows with every iteration and starts to use swap after the 8th iteration. Could you please help check whether there is a memory leak in the latest release?