ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Training seg faults after a few iterations #341

Open bolipop opened 9 months ago

bolipop commented 9 months ago

Hello, I'm trying to train a simple network (a MobileNet classifier). Training starts out fine, but I get a segfault after a few batches. I'm hoping someone can point out what I'm doing wrong, or give me some pointers for debugging the segfault, since it just errors out without a useful traceback. Thanks!
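One standard-library way to get at least a Python-level traceback out of a crash like this is `faulthandler`, which installs handlers for SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL. A minimal sketch, assuming it is placed at the top of the training script (running `python3 -X faulthandler mobilenet/main.py` has the same effect). It only shows the Python frames at the moment of the crash, not the native MLX stack; for that, running under a debugger such as lldb would be needed.

```python
import faulthandler

# Install handlers for fatal signals (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL).
# On a crash, the Python traceback of every thread is written to stderr
# before the process exits.
faulthandler.enable(all_threads=True)

# ... the rest of the training script (imports, model, main()) goes here ...
print("faulthandler enabled:", faulthandler.is_enabled())
```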

MacBook Pro, M2 Max, 32 GB

import itertools

import mlx.core as mx
import numpy as np
import mlx.nn as nn
import mlx.optimizers as optim
from datasets import load_dataset
import cv2
from tqdm import tqdm

class DSConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.depth = [
            nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
            for _ in range(in_channels)
        ]
        self.bn1 = nn.LayerNorm(in_channels)
        self.point = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True)
        self.bn2 = nn.LayerNorm(out_channels)

    def __call__(self, x):
        x = x.split(x.shape[-1], axis=-1)  # Split across channels
        depth = mx.concatenate([l(_x) for l, _x in zip(self.depth, x)], axis=-1)
        point = self.point(nn.relu(self.bn1(depth)))
        return nn.relu(self.bn2(point))

class MobileNet(nn.Module):
    def __init__(self, input_channels, num_classes, slim: bool = False):
        super().__init__()

        self.input_conv = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.LayerNorm(32)

        layers = [
            DSConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
            DSConv(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            DSConv(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            DSConv(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
        ]
        if not slim:
            for _ in range(5):
                layers += [
                    DSConv(512, 512, kernel_size=3, stride=1, padding=1),
                ]

        layers += [
            DSConv(512, 1024, kernel_size=3, stride=2, padding=1),
            DSConv(1024, 1024, kernel_size=3, stride=2, padding=4),
        ]
        self.layers = nn.Sequential(*layers)

        self.linear = nn.Linear(1024, num_classes)

    def __call__(self, x):
        x = nn.relu(self.bn1(self.input_conv(x)))
        x = self.layers(x)
        x = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
        x = self.linear(x)

        return x

def grouper(iterable, n, *, incomplete="fill", fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks"

    args = [iter(iterable)] * n
    match incomplete:
        case "fill":
            return itertools.zip_longest(*args, fillvalue=fillvalue)
        case "strict":
            return zip(*args, strict=True)
        case "ignore":
            return zip(*args)
        case _:
            raise ValueError("Expected fill, strict, or ignore")

def collate(rows):
    # Convert grayscale images to 3 channels, then resize with OpenCV
    _images = [np.array(item["image"]) for item in rows]
    _images = [cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) if len(im.shape) == 2 else im for im in _images]
    _images = [cv2.resize(im, (224, 224), interpolation=cv2.INTER_CUBIC) for im in _images]
    images = np.array(_images, dtype=np.float32)
    images /= 255.0

    labels = np.array([item["label"] for item in rows], dtype=np.uint32)
    return mx.array(images), mx.array(labels)  # images: (b, h, w, c), labels: (b,)

def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))

def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)

def main():
    batch_size = 2

    model = MobileNet(input_channels=3, num_classes=1000, slim=True)
    mx.eval(model.parameters())

    loss_helper = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=0.1)

    datasets = load_dataset("imagenet-1k", trust_remote_code=True)
    # Drop a trailing partial batch instead of padding it with fill values
    for rows in tqdm(grouper(datasets["train"], batch_size, incomplete="ignore")):
        images, labels = collate(rows)
        _loss, grads = loss_helper(model, images, labels)
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)

if __name__ == "__main__":
    main()

21it [00:11, 1.81it/s]zsh: segmentation fault python3 mobilenet/main.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

awni commented 9 months ago

Hmm. Could you try just looping over the data without using MLX? That would tell us whether this is an MLX issue or something to do with the datasets package you are using.
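The converse experiment is also cheap: keep MLX but take `datasets` and OpenCV out of the loop by feeding the model random batches with the same shapes and dtypes that `collate` produces. A minimal NumPy sketch, assuming 224x224 RGB inputs and 1000 classes as in the script; `synthetic_batches` is a hypothetical helper name.

```python
import numpy as np


def synthetic_batches(num_batches, batch_size=2, size=224, num_classes=1000, seed=0):
    """Yield (images, labels) NumPy batches shaped like collate()'s output.

    images: float32, (batch_size, size, size, 3), values in [0, 1)
    labels: uint32, (batch_size,), values in [0, num_classes)
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        images = rng.random((batch_size, size, size, 3), dtype=np.float32)
        labels = rng.integers(0, num_classes, size=batch_size, dtype=np.uint32)
        yield images, labels


# Hypothetical use in the training loop from the issue:
# for images, labels in synthetic_batches(1000):
#     _loss, grads = loss_helper(model, mx.array(images), mx.array(labels))
#     optimizer.update(model, grads)
#     mx.eval(model.parameters(), optimizer.state)
```

If the segfault still shows up with synthetic data, the data pipeline is ruled out.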

It's also good to monitor your memory while you do so, to see if there is a leak or if you are using way too much (use Activity Monitor or asitop).
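If you'd rather log memory from inside the training loop than watch a GUI, the standard library's `resource` module reports the process's peak resident set size. A minimal sketch (`peak_rss_mb` is a hypothetical helper). This is only what the OS counts as resident for the process and may not capture every Metal allocation, so treat it as a coarse signal alongside Activity Monitor or asitop; MLX also exposes its own memory-reporting functions, which are worth looking up in the documentation for your version.

```python
import resource
import sys


def peak_rss_mb():
    """Peak resident set size of this process so far, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


# Hypothetical use inside the training loop:
# if step % 50 == 0:
#     print(f"step {step}: peak RSS {peak_rss_mb():.0f} MiB")
print(f"peak RSS so far: {peak_rss_mb():.1f} MiB")
```

Since the value is a running peak, a leak shows up as a number that keeps climbing, while a high but stable footprint plateaus.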

bolipop commented 9 months ago

When I comment out these lines, I'm able to loop through the entire dataset just fine:

_loss, grads = loss_helper(model, images, labels)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)

Looking at the memory usage, I suspect it's running out of memory.

[Screenshot 2024-01-01 at 11 06 15 PM]

awni commented 9 months ago

It doesn't look like it's out of memory, and it definitely shouldn't segfault. Does it segfault reliably for you? How far into training does it happen?

awni commented 9 months ago

I'm running your script on an M1 Max with 32 GB. So far no segfault 🤷‍♂️; I'm at iteration 600. Did it segfault before that?

Also, what's your OS? What version of MLX are you using (or the commit hash, if built from source)?
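A small sketch that gathers these details in one go using only the standard library (`environment_report` is a hypothetical helper). For a source build, `importlib.metadata` reports the installed package version rather than the commit hash; `git rev-parse HEAD` in the checkout gives the latter.

```python
import platform
import sys
from importlib import metadata


def environment_report():
    """Collect OS, hardware, Python, and MLX version details for a bug report."""
    try:
        mlx_version = metadata.version("mlx")
    except metadata.PackageNotFoundError:
        mlx_version = "not installed"
    return {
        "os": platform.platform(),
        "macos": platform.mac_ver()[0] or "n/a",  # empty string off macOS
        "machine": platform.machine(),
        "python": sys.version.split()[0],
        "mlx": mlx_version,
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```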

bolipop commented 9 months ago

macOS Sonoma 14.2.1, M2 Max, 32 GB, Python 3.11.7

Yeah, I've had it segfault right away before; it's very sporadic. Sometimes it just hangs and I have to kill the process manually.
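For the hangs, `faulthandler` can also act as a watchdog: `faulthandler.dump_traceback_later(timeout)` dumps every thread's traceback if it is not cancelled or re-armed within `timeout` seconds, which shows where the process is stuck without killing it. A minimal sketch, assuming you re-arm it once per training iteration; the helper names and the 120-second timeout are hypothetical choices.

```python
import faulthandler

# Hypothetical value: pick something comfortably longer than one training step.
HANG_TIMEOUT_S = 120


def arm_hang_watchdog(timeout=HANG_TIMEOUT_S):
    # A new call replaces the previous timer, so calling this once per
    # iteration only fires if a single iteration takes longer than `timeout`.
    # exit=False keeps the process alive after the dump so it can still be
    # inspected (for example by attaching a debugger).
    faulthandler.dump_traceback_later(timeout, repeat=False, exit=False)


def disarm_hang_watchdog():
    # Cancel the pending dump, e.g. once the training loop has finished.
    faulthandler.cancel_dump_traceback_later()
```

Calling `arm_hang_watchdog()` at the top of each loop iteration and `disarm_hang_watchdog()` after the loop means a step that never finishes leaves a traceback in the terminal.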

mlx ❯ python3 mobilenet/train.py
9it [00:04,  2.30it/s]zsh: segmentation fault  python3 mobilenet/train.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

~/Repos/mlx-stuff/mlx-playground main* 8s
mlx ❯ python3 mobilenet/train.py
69it [01:18,  2.02it/s]zsh: segmentation fault  python3 mobilenet/train.py
/Users/bento/.pyenv/versions/3.11.7/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

~/Repos/mlx-stuff/mlx-playground main* 1m 22s
mlx ❯ python3 mobilenet/train.py
12it [00:05,  2.29it/s]

bolipop commented 9 months ago

You're right; I thought the little widget on the right was tracking memory.

awni commented 9 months ago

What about your MLX version (or commit hash if building from source)?

bolipop commented 9 months ago

0.0.6

bolipop commented 9 months ago

Not sure if it helps, but earlier I saw a bus error instead of a segfault.

angeloskath commented 9 months ago

I can't reproduce it either. I left it running for about an hour on my M2 Air. My initial thought was that it had to do with the implementation of the separable convolution, which ends up having 1000 layers and concatenating 1000 arrays, but that doesn't seem to cause a problem at all.
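To make that point concrete: the `DSConv.depth` list of per-channel `nn.Conv2d(1, 1, ...)` modules followed by `mx.concatenate` computes the same result as one depthwise convolution over all channels at once. The sketch below is plain NumPy, not MLX, with hypothetical helpers (`conv2d_single_channel`, `depthwise_conv2d`, `per_channel_then_concat`) in the NHWC layout MLX uses; it computes cross-correlation as deep learning frameworks do, omits the per-channel bias for brevity, and just checks that the two formulations agree. Whether your MLX version can express this as a single call (for example via a `groups` argument on the convolution) is worth checking in its documentation.

```python
import numpy as np


def conv2d_single_channel(x, w, stride=1, padding=0):
    """Cross-correlate a (N, H, W) single-channel input with a (kH, kW) kernel."""
    x = np.pad(x, ((0, 0), (padding, padding), (padding, padding)))
    kh, kw = w.shape
    _, h, wd = x.shape
    oh = (h - kh) // stride + 1
    ow = (wd - kw) // stride + 1
    out = np.zeros((x.shape[0], oh, ow), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            out += w[i, j] * x[:, i:i + stride * oh:stride, j:j + stride * ow:stride]
    return out


def per_channel_then_concat(x, w, stride=1, padding=0):
    """What DSConv.depth does: one conv per channel, then concatenate."""
    outs = [
        conv2d_single_channel(x[..., c], w[..., c], stride, padding)[..., None]
        for c in range(x.shape[-1])
    ]
    return np.concatenate(outs, axis=-1)


def depthwise_conv2d(x, w, stride=1, padding=0):
    """Single depthwise conv. x: (N, H, W, C); w: (kH, kW, C), one kernel per channel."""
    x = np.pad(x, ((0, 0), (padding, padding), (padding, padding), (0, 0)))
    kh, kw, _ = w.shape
    _, h, wd, c = x.shape
    oh = (h - kh) // stride + 1
    ow = (wd - kw) // stride + 1
    out = np.zeros((x.shape[0], oh, ow, c), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            # w[i, j] has shape (C,) and broadcasts over the channel axis.
            out += w[i, j] * x[:, i:i + stride * oh:stride, j:j + stride * ow:stride, :]
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 8, 4)).astype(np.float32)  # (N, H, W, C)
w = rng.standard_normal((3, 3, 4)).astype(np.float32)     # one 3x3 kernel per channel
a = per_channel_then_concat(x, w, stride=2, padding=1)
b = depthwise_conv2d(x, w, stride=2, padding=1)
print(a.shape, np.allclose(a, b))
```

The per-channel version builds and concatenates C separate outputs, while the depthwise version does the same arithmetic in a handful of vectorized operations, which is why the single-call form is usually preferred.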

awni commented 1 month ago

@bolipop, if you could run it again with the latest MLX and let us know whether it still segfaults for you, that would be useful. I believe we've fixed the underlying issue, but it's hard to be sure since we never reproduced this exact one.