ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Memory Leakage Issue in MLX 0.16 #1271

Closed sachinraja13 closed 1 month ago

sachinraja13 commented 1 month ago

Hi Team,

I am using the DETR implementation from the MIMM package here. With MLX 0.15, memory usage during training stayed nearly constant at 11.46 GB. After installing MLX 0.16, however, memory consumption grows with every iteration and training starts using swap memory after the 8th iteration. Could you please check whether there is a memory leak in the latest release?
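Growth like this can be logged per iteration rather than watched by eye. The following is a minimal sketch using only the Python standard library; `peak_rss_mb` and `log_memory` are hypothetical helper names, and the workload in the `__main__` block is a stand-in, not the DETR training step. Process-level peak RSS may not capture every allocation made by the framework, so treat it as a rough signal only.

```python
import resource
import sys


def peak_rss_mb() -> float:
    """Return the peak resident set size of this process, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def log_memory(step_fn, num_iterations):
    """Call step_fn(i) repeatedly and record peak RSS after each call."""
    readings = []
    for i in range(num_iterations):
        step_fn(i)
        readings.append(peak_rss_mb())
        print(f"iteration {i}: peak RSS {readings[-1]:.1f} MiB")
    return readings


if __name__ == "__main__":
    held = []  # deliberately retains buffers to imitate a leaking step
    log_memory(lambda i: held.append(bytearray(8 * 1024 * 1024)), 5)
```

In a real run, `step_fn` would wrap one training iteration. Readings that keep climbing long after the first few iterations suggest that something is being retained.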

awni commented 1 month ago

Could you share a snippet to reproduce the memory leak?

sachinraja13 commented 1 month ago

To reproduce:

  1. Clone this repo.
  2. Run the following piece of code:
    
    import os
    import json
    import numpy as np
    import time
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    from mimm.models.detection.detr.utils import NestedTensor
    from mimm.models.detection.detr import detr

    def batch_images_and_masks(inputs):
        images = []
        masks = []
        original_image_sizes = []
        for i in range(len(inputs)):
            images.append(inputs[i]['image'])
            masks.append(inputs[i]['mask'])
            original_image_sizes.append(inputs[i]['orig_size'])
        images = mx.array(images)
        masks = mx.array(masks)
        original_image_sizes = mx.array(original_image_sizes)
        return images, masks, original_image_sizes

    inputs = []
    i = {}
    i['image'] = mx.random.uniform(shape=(1024, 1024, 3))
    i['mask'] = mx.zeros(shape=(1024, 1024))
    i['orig_size'] = mx.array([1024, 1024])
    inputs.append(i)
    targets = []
    t = {}
    t['boxes'] = mx.random.uniform(shape=(50, 4))
    t['labels'] = mx.ones((50), dtype=mx.int64)
    targets.append(t)
    images, masks, original_image_sizes = batch_images_and_masks(inputs)

    import numpy as np
    import mlx.core as mx
    import mlx.nn as nn
    from scipy.optimize import linear_sum_assignment

    class DETRMatcher:
        def __init__(self, cost_class: float = 1.0, cost_bbox: float = 5.0, cost_giou: float = 2.0):
            super().__init__()
            self.cost_class = cost_class
            self.cost_bbox = cost_bbox
            self.cost_giou = cost_giou

        def forward(self, outputs, targets):
            bs, num_queries = outputs['pred_logits'].shape[:2]
            out_prob = nn.softmax(outputs['pred_logits'].flatten(0, 1))  # [batch_size * num_queries, num_classes]
            out_bbox = outputs['pred_boxes'].flatten(0, 1)  # [batch_size * num_queries, 4]
            tgt_ids = mx.concatenate([v['labels'] for v in targets])
            tgt_bbox = mx.concatenate([v['boxes'] for v in targets])
            cost_class = -out_prob[:, tgt_ids]
            cost_bbox = self.compute_l1_distance(out_bbox, tgt_bbox)
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = np.asarray(C.reshape(bs, num_queries, -1))
            sizes = [len(v['boxes']) for v in targets]
            indices = []
            for i, c in enumerate(C):
                # print(c.shape)
                if i == 0:
                    start_index = 0
                    end_index = start_index + sizes[i]
                else:
                    start_index = sizes[i-1]
                    end_index = start_index + sizes[i]
                cost_matrix = c[:, start_index:end_index]
                indices.append(linear_sum_assignment(cost_matrix))
            return [(mx.array(i, dtype=mx.int64), mx.array(j, dtype=mx.int64)) for i, j in indices]

        @staticmethod
        def compute_l1_distance(src_boxes, tgt_boxes):
            src_boxes = src_boxes[:, None, :]  # [batch_size * num_queries, 1, 4]
            tgt_boxes = tgt_boxes[None, :, :]  # [1, num_targets, 4]
            return mx.sum(mx.abs(src_boxes - tgt_boxes), axis=-1)  # [batch_size * num_queries, num_targets]

    def generalized_box_iou(boxes1, boxes2):
        iou, union = box_iou(boxes1, boxes2)
        enclose_area = (mx.maximum(boxes1[:, None, 2:], boxes2[:, 2:]) - mx.minimum(boxes1[:, None, :2], boxes2[:, :2])).prod(2)
        giou = iou - (enclose_area - union) / (enclose_area + mx.array(1e-7))
        return giou

    def box_iou(boxes1, boxes2):
        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

        lt = mx.maximum(boxes1[:, None, :2], boxes2[:, :2])
        rb = mx.minimum(boxes1[:, None, 2:], boxes2[:, 2:])

        wh = mx.clip(rb - lt, 0, None)
        inter = wh[:, :, 0] * wh[:, :, 1]

        union = area1[:, None] + area2 - inter
        iou = inter / union
        return iou, union

    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
        b = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
        return mx.stack(b, axis=1)

    def box_xyxy_to_cxcywh(x):
        x0, y0, x1, y1 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
        b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
        return mx.stack(b, axis=1)

    class DETRLoss:
        def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
            super().__init__()
            self.num_classes = num_classes
            self.matcher = matcher
            self.weight_dict = weight_dict
            self.eos_coef = eos_coef
            self.losses = losses
            self.empty_weight = mx.ones(self.num_classes + 1)
            self.empty_weight[0] = self.eos_coef

        def loss_labels(self, outputs, targets, indices, num_boxes):
            src_logits = outputs['pred_logits']
            idx = self._get_src_permutation_idx(indices)
            target_classes_o = mx.concatenate([t["labels"][J] for t, (_, J) in zip(targets, indices)])
            target_classes = mx.full(src_logits.shape[:2], 0, dtype=mx.int32)
            target_classes[idx] = target_classes_o
            loss_ce = mx.mean(nn.losses.cross_entropy(src_logits, target_classes, weights=self.empty_weight[target_classes]))
            losses = {'loss_ce': loss_ce}

            return losses

        def loss_boxes(self, outputs, targets, indices, num_boxes):
            idx = self._get_src_permutation_idx(indices)
            src_boxes = outputs['pred_boxes'][idx]
            target_boxes = mx.concatenate([t['boxes'][i] for t, (_, i) in zip(targets, indices)], axis=0)
            loss_bbox = nn.losses.l1_loss(src_boxes, target_boxes, reduction='none')
            losses = {}
            losses['loss_bbox'] = loss_bbox.sum() * 5.0 / num_boxes
            loss_giou = 1 - mx.diag(generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes),
                box_cxcywh_to_xyxy(target_boxes)))
            losses['loss_giou'] = loss_giou.sum() * 2.0 / num_boxes

            return losses

        def _get_src_permutation_idx(self, indices):
            batch_idx = mx.concatenate([mx.full((src.shape[0],), i, dtype=mx.int64) for i, (src, _) in enumerate(indices)])
            src_idx = mx.concatenate([src for (src, _) in indices])
            return batch_idx, src_idx

        def forward(self, outputs, targets):
            indices = self.matcher.forward(outputs, targets)
            num_boxes = sum(len(t["labels"]) for t in targets)
            num_boxes = mx.array([num_boxes], dtype=mx.float32)
            num_boxes = mx.maximum(num_boxes, mx.array([1.0], dtype=mx.float32))
            losses = {}
            for loss in self.losses:
                loss_fn = getattr(self, f'loss_{loss}')
                losses.update(loss_fn(outputs, targets, indices, num_boxes))

            losses = {k: v for k, v in losses.items()}
            return losses

    class DETRTrainer:
        def __init__(self, model, criterion, optimizer):
            self.model = model
            self.criterion = criterion
            self.optimizer = optimizer

        def loss_fn(self, nested_tensor, targets):
            outputs = self.model(nested_tensor)
            loss_dict = self.criterion.forward(outputs, targets)
            loss = sum(loss for loss in loss_dict.values())[0]
            return loss

        def train_one_epoch(self, epoch):
            self.model.train()
            total_loss = 0.0
            iteration = 0
            total_iterations = 500
            for i in range(total_iterations):
                images, masks, orig_sizes = batch_images_and_masks(inputs)
                t = NestedTensor(images, masks)
                loss_value, grads = self.loss_and_grad_fn(t, targets)
                grads, total_norm = optim.clip_grad_norm(grads, max_norm=0.1)
                self.optimizer.update(self.model, grads)
                total_loss += loss_value.item()
                if iteration % 1 == 0:
                    print(f'Epoch: {epoch}, Iteration: {iteration}, Loss: {loss_value.item()}')
                iteration += 1
                # break
            avg_loss = total_loss / total_iterations
            print(f'Epoch {epoch} completed. Average Loss: {avg_loss}')

        def fit(self, num_epochs):
            self.loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn)
            for epoch in range(num_epochs):
                self.train_one_epoch(epoch)
                # Optionally, you can add validation here
                # self.evaluate(self.data_loader)

    # Example usage

    if __name__ == "__main__":
        mx.set_default_device(mx.gpu)
        model, processor = detr.detr_resnet50(num_classes=1, num_queries=100, hidden_dim=256)
        matcher = DETRMatcher()
        criterion = DETRLoss(num_classes=1, matcher=matcher, weight_dict=None, eos_coef=0.1, losses=["labels", "boxes"])  # Adjust parameters as needed
        lr_schedule = optim.step_decay(0.0001, 0.5, 2000)
        optimizer = optim.AdamW(learning_rate=lr_schedule)
        trainer = DETRTrainer(model, criterion, optimizer)
        trainer.fit(num_epochs=10)



In Activity Monitor, you will see the memory used by the Python process climb very quickly. This was not the case with MLX 0.15. I did not make any changes to the code in the MIMM package.

awni commented 1 month ago

Cool, thanks for sharing that. I can indeed reproduce it. It looks like it was introduced between 0.15.2 and 0.16. I will bisect and look for the issue.

awni commented 1 month ago

It started showing up after https://github.com/ml-explore/mlx/pull/1246.

sachinraja13 commented 1 month ago

Many thanks for your help with this.

awni commented 1 month ago

So there does seem to be a leak of some sort. However, if you eval the model and optimizer state after each iteration (which is good practice), the leak seems to go away. So I would recommend that you do that:

            self.optimizer.update(self.model, grads)
            mx.eval(self.model, self.optimizer.state)  # <-- add that line
            total_loss += loss_value.item()

awni commented 1 month ago

A simple repro which shows the leak:

import mlx.core as mx

if __name__ == "__main__":

    params = [mx.zeros((256, 256))]

    def loss_fn(params, x):
        x = x @ params[0].T
        outputs = mx.fast.layer_norm(x, None, None, 1e-4)
        mx.eval(outputs)
        return outputs.sum()

    loss_and_grad_fn = mx.value_and_grad(loss_fn)

    src = mx.random.uniform(shape=(1, 32, 32, 256))
    for i in range(50000):
        loss_value, grads = loss_and_grad_fn(params, src)
        params[0] = grads[0] + params[0]
        print(mx.metal.get_peak_memory())

sachinraja13 commented 1 month ago

So there does seem to be a leak of some sort. However, if you eval the model and optimizer state after each iteration (which is good practice), the leak seems to go away. So I would recommend that you do that:

            self.optimizer.update(self.model, grads)
            mx.eval(self.model, self.optimizer.state)  # <-- add that line
            total_loss += loss_value.item()

Thank you, @awni! This solved the problem.

awni commented 1 month ago

This is mostly closed by #1274