ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Issue encountered in solving 2D Heat Equation: needing mx.eval to avoid segmentation fault #101

Open sck-at-ucy opened 11 months ago

sck-at-ucy commented 11 months ago

I have implemented a simple solution of the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs. I have implemented the code using both PyTorch and the MLX framework, and I am testing the relative performance on an M2 Ultra with 128 GB of memory.

The MLX code is included below. So far, performance in various tests (on the same machine) shows the MLX version to be somewhere between 2x and 10x faster, depending on the problem size.

However, I have an issue that I need to understand. Depending on the problem size, I need to include the line

if step % 15000 == 0: mx.eval(T)

to avoid a segmentation fault. I imagine this has to do with lazy evaluation and arrays sitting in buffers? My issue is that I currently determine empirically, for each case, how often I need to call mx.eval. Is there some programmatic and more elegant way to automatically issue the mx.eval at the right frequency based on the problem size?
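The best I have come up with so far is a rough heuristic like the sketch below, where the constants are pure guesses I tune by hand (nothing here is an MLX API): derive the interval from a budget on queued operations.

import mlx.core as mx

nx, ny = 5000, 5000
max_steps = 10_000
T = mx.zeros([nx, ny])

# Hypothetical heuristic: cap how many lazy ops accumulate between evals.
ops_per_step = 20       # rough count of array ops enqueued per time step (a guess)
op_budget = 20_000      # queued ops to allow before forcing an eval (a guess)
eval_every = max(1, op_budget // ops_per_step)

for step in range(max_steps):
    T = T + 1.0         # stand-in for the real finite-difference update
    if step % eval_every == 0:
        mx.eval(T)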

Here is the complete code below. Thank you for all your help @awni !

# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
import numpy as np
import matplotlib.pyplot as plt
import time
import mlx.core as mx

# Convergence tolerance to stop early (currently disabled)
#convergence_tolerance = 1e-8

# Grid size and material properties setup
nx, ny = 5000, 5000  # Set grid dimensions
k = 1.0              # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 10000 # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0]   # Grid spacing in x direction
dy = y[1] - y[0]   # Grid spacing in y direction

# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))
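# (The returned value is the standard stability bound for this explicit scheme:
#  dt_max = 1 / (2 * alpha * (1/dx**2 + 1/dy**2)).)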

# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0   # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0   # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

start_time = time.time()  # Capture start time
for step in range(max_steps):
    T_old = mx.broadcast_to(T,shape=T.shape)

    # Update interior points using finite difference method
    # Pad the interior points for broadcasting
    T =  mx.pad(mx.pad( (T_old[1:-1,1:-1] + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_old[1:-1, 1:-1] + T_old[:-2, 1:-1]) / dx**2 +
        (T_old[1:-1, 2:] - 2 * T_old[1:-1, 1:-1] + T_old[1:-1, :-2]) / dy**2
    )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)

    if step % 15000 == 0:
        mx.eval(T)

end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

#  Visualizing the temperature field using matplotlib
plt.imshow(T, cmap='hot', interpolation='nearest')
plt.colorbar()  # Add a color bar to indicate temperature scales
plt.show()
awni commented 11 months ago

It should never segfault, so this isn't something you are doing sub-optimally with eval. It looks like a bug to me... but we'll have to investigate further.

sck-at-ucy commented 11 months ago

Ok knowing that, I will try to do some debugging and see if I can provide some additional feedback.

awni commented 11 months ago

So for some reason it segfaults if I don't put the mx.eval(T) at the beginning (e.g. before starting the loop). That definitely smells like a bug. But otherwise you don't need an eval until the end.

Also, I made a couple of slight modifications; it should be noticeably faster (assuming I didn't mess anything up).

# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
import numpy as np
import time
import mlx.core as mx

# Grid size and material properties setup
nx, ny = 5000, 5000  # Set grid dimensions
k = 1.0              # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 10000 # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0]   # Grid spacing in x direction
dy = y[1] - y[0]   # Grid spacing in y direction

# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))

# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0   # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0   # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

dysq = dy**2
dxsq = dx**2
mx.eval(T)

start_time = time.time()  # Capture start time
for step in range(max_steps):
    # Update interior points using finite difference method
    # Pad the interior points for broadcasting
    T_mid = T_old[1:-1,1:-1]
    T =  mx.pad(mx.pad((T_mid + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
        (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)

mx.eval(T)
end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")
sck-at-ucy commented 11 months ago

It runs orders of magnitude faster (438 sec vs 2.37 sec for 100K steps)!!! And whatever the bug is, it's not surfacing any more. I'm thrilled with the speedup πŸ‘

sck-at-ucy commented 11 months ago

I was too quick to respond. The T_old temperature field (see the code below) needs to be updated in each time step; that's why I had the line T_old = mx.broadcast_to(T, shape=T.shape) at the beginning of each time step, which I found was a bit faster than setting T_old = T. Once I include this line again, the problem resurfaces and the speed goes back to more or less what I was getting before.

    T_mid = T_old[1:-1, 1:-1]
    T = mx.pad(mx.pad((T_mid + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
        (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)

awni commented 11 months ago

Oops, sorry about that!

It's very odd that doing T_old = mx.broadcast_to(T,shape=T.shape) is faster than T_old=T πŸ€”. More to investigate here.
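A minimal way to isolate that comparison (a sketch; timings will vary with machine load):

import time
import mlx.core as mx

def run(use_broadcast, iters=1000):
    T = mx.zeros([5000, 5000])
    mx.eval(T)
    tic = time.time()
    for i in range(iters):
        T_old = mx.broadcast_to(T, shape=T.shape) if use_broadcast else T
        T = T_old + 1.0
        if i % 200 == 0:
            mx.eval(T)  # keep the pending graph small
    mx.eval(T)
    return time.time() - tic

print(f"T_old = T:          {run(False):.3f} s")
print(f"mx.broadcast_to(T): {run(True):.3f} s")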

sck-at-ucy commented 11 months ago

Following your remark, I tested again T_old = mx.broadcast_to(T, shape=T.shape) and T_old = T, and they are about the same (within load noise). I must have been misled by some other change I made in parallel. So, this does not need to be checked. Sorry for the wrong claim 🫣

sck-at-ucy commented 11 months ago

Ok, I did some basic debugging using the debugger in PyCharm (PyDev). I inserted a breakpoint at the iteration step where I would normally need to execute mx.eval(T) to avoid the segfault. Stepping through, all the values in the arrays look normal and nothing weird happens. However, when I take the first step forward from the breakpoint, it takes a long time for the variable values to load in the debugger. I can then step through the code and complete execution correctly. I am not sure if this adds any useful hints. Suspicion: am I effectively issuing an eval by loading the values in the debugger, and is that why the code completes correctly without needing the explicit mx.eval(T) πŸ€”? To my non-expert eyes this looks like a low-level error in MLX, as there is nothing funny happening in the actual array values. Probably something with the dynamic graphs and memory?
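If that suspicion is right, it should be reproducible without the debugger: anything that materializes values (printing an array, calling .item(), converting to NumPy) forces evaluation of the pending graph just like an explicit mx.eval. A minimal sketch of that idea:

import mlx.core as mx

a = mx.zeros([4, 4])
for _ in range(100):
    a = a + 1.0  # queues 100 lazy additions

# No explicit mx.eval(a): reading a concrete value evaluates the queued
# graph as a side effect -- presumably the same thing the debugger does
# when it displays the variable.
print(a[0, 0].item())  # 100.0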

awni commented 11 months ago

I'm not able to reproduce the segfault. Does it segfault for you using the settings you have in the example above?

Also what hardware / OS are you using?

sck-at-ucy commented 11 months ago

Yes, but below I also provide the latest version (with your suggestions added), which still segfaults if I don't call mx.eval(T) or mx.eval(T_old) at step % 15000 == 0. The exact frequency at which mx.eval is needed depends on the problem size (nx, ny).

Hardware

Model Name: Mac Studio
Model Identifier: Mac14,14
Model Number: Z180000DEB/A
Chip: Apple M2 Ultra
Total Number of Cores: 24 (16 performance and 8 efficiency)
Memory: 128 GB
System Firmware Version: 10151.41.12
OS Loader Version: 10151.41.12

OS Version

macOS Sonoma Version 14.1.2

Python Environment

Python 3.11.6
numpy 1.26.2
mlx 0.0.4
Conda 23.11.0
libblas 3.9.0 20_osxarm64_accelerate conda-forge
libcblas 3.9.0 20_osxarm64_accelerate conda-forge
liblapack 3.9.0 20_osxarm64_accelerate conda-forge

Segfault Message

Elapsed time: 2.99 seconds

Process finished with exit code 139 (interrupted by signal 11:SIGSEGV)

Current version of code

# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
import numpy as np
import time
import mlx.core as mx

# Grid size and material properties setup
nx, ny = 5000, 5000  # Set grid dimensions
k = 1.0              # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 50000 # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0]   # Grid spacing in x direction
dy = y[1] - y[0]   # Grid spacing in y direction

# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))

# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0   # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0   # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

dysq = dy**2
dxsq = dx**2
mx.eval(T)
mx.eval(T_old)
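# (Per the discussion above, evaluating once here, before the loop starts,
#  avoids an immediate segfault.)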

start_time = time.time()  # Capture start time
for step in range(max_steps):
    #T_old = mx.broadcast_to(T, shape=T.shape)
    T_old = T
    # Update interior points using finite difference method
    # Pad the interior points for broadcasting

    #### Segfault Control ####
    # if step % 15000 == 0: mx.eval(T_old)    #(When commented out it should cause a segfault)

    T_mid = T_old[1:-1,1:-1]

    T =  mx.pad(mx.pad((T_mid + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
        (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)

end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

mx.eval(T)

# Visualizing the temperature field using matplotlib
import matplotlib.pyplot as plt
plt.imshow(T, cmap='hot', interpolation='nearest')
plt.colorbar()  # Add a color bar to indicate temperature scales
plt.show()
sck-at-ucy commented 11 months ago

One more observation that might provide a clue. These two slightly different implementations (with the problem size nx, ny kept the same between them) require different frequencies of mx.eval(T_old) to avoid a segfault. The difference between the two is that in the first case mx.pad is used to add the top and bottom rows and the Neumann BCs are then applied by assignment, while in the second, the rows are added and assigned values via mx.concatenate.

Implementation 1: eval every 21000 steps

    if step % 21000 == 0:
        mx.eval(T_old)

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(mx.pad((T_mid + dt * k * (
            (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
            (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((1, 1), (0, 1)), 1), ((0, 0), (1, 0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T[0, :] = T[1, :]
    T[-1, :] = T[-2, :]

Implementation 2: eval every 16000 steps

    if step % 16000 == 0:
        mx.eval(T_old)

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(mx.pad((T_mid + dt * k * (
            (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
            (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0, 0), (0, 1)), 1), ((0, 0), (1, 0)), 0)

    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)
awni commented 11 months ago

I am able to reproduce the segfault on my machine.

Current (and likely) hypothesis is that the segfault is a result of a stack overflow during the eval, when we recurse on the inputs to build the compute graph. It definitely makes sense that it depends on the size of the graph.

For now, I would recommend just doing a fixed amount of compute per eval, so adding something like:

if step % 1000 == 0:
    mx.eval(T_old)

I am not sure if we should fix this... we could try using iterative graph construction to build the graph, but at the same time I don't recommend letting graphs get so big. It's usually a sign that eval should be used more frequently and/or the number of operations should be reduced.

Another comment:

You should see if it's possible to replace the whole inner computation with something like a convolution or matmul with the appropriate kernel. It would dramatically reduce the number of operations which in this case would speed things up substantially. I think the whole update from T_old to T is linear so it should be very doable with the right operation.

sck-at-ucy commented 11 months ago

Ok, at least I now have some level of understanding as to what is going on.

The idea to use a convolution kernel is a great one; I will try it, because it connects to the next project I want to tackle with MLX for my computational engineering course. I know how I can do that with Torch; I'll see if I can make it work with MLX and will come back for help if I run into trouble.

BTW, if you like the idea, it might be useful to have a section of examples on engineering/physics applications (not necessarily falling directly under the ML/AI area) to provide an easy entry point for people wanting to leverage MLX for science. I would be happy to contribute in this direction.

Here's the latest Black-formatted version of the code, before I implement the suggested change:

2D_Heat_Equation_MLX_A.py
# Solve the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
# with the Apple MLX framework.
# Stavros Kassinos
# Example for course MME 419 @ UCY, December 2023
import time
import mlx.core as mx

# Select Execution Device
mx.set_default_device(mx.gpu)

# Grid size and material properties setup
nx, ny = 1_000, 1_000  # Set grid dimensions
k = 1.0  # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 50_000  # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.arange(0, 1.0, 1.0 / nx)
y = mx.arange(0, 1.0, 1.0 / ny)

dx = x[1] - x[0]  # Grid spacing in x direction
dy = y[1] - y[0]  # Grid spacing in y direction


# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1 / dx**2 + 1 / dy**2))


# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0  # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(
    dt_max, desired_dt
)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0  # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

dysq = dy**2
dxsq = dx**2

start_time = time.time()  # Capture start time
for step in range(max_steps):
    T_old = T  # Update previous step T-field

    # Eval to keep execution graph efficient
    if step % 20_000 == 0:
        mx.eval(T)

    # Update interior points using finite difference method
    #    Pad the interior points for broadcasting to T.shape
    #    Enforce Dirichlet BCs

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(
        mx.pad(
            (
                T_mid
                + dt
                * k
                * (
                    (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq
                    + (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
                )
            ),
            ((1, 1), (0, 1)),
            1,
        ),
        ((0, 0), (1, 0)),
        0,
    )

    # Update Neumann boundaries (zero-flux) at top and bottom
    T[0, :] = T[1, :]
    T[-1, :] = T[-2, :]

end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

mx.eval(T)

# Visualizing the temperature field using matplotlib
import matplotlib.pyplot as plt

plt.imshow(T, cmap="hot", interpolation="nearest")
plt.colorbar()  # Add a color bar to indicate temperature scales
plt.show()
sck-at-ucy commented 11 months ago

I converted the code to use mx.conv2d and to my surprise it is taking about 3 times longer than the version shared just above in my last comment. Am I doing something wrong? I was confident that I would be seeing time savings...

import time
import mlx.core as mx

# Set the GPU as the default execution device for MLX operations
mx.set_default_device(mx.gpu)

# Grid size setup (1000x1000) and thermal conductivity definition
nx, ny = 1_000, 1_000
k = 1.0

# Time-stepping parameters: desired time step and maximum number of steps
desired_dt = 0.01
max_steps = 50_000

# Create linearly spaced grids for x and y, and compute grid spacings
x = mx.arange(0, 1.0, 1.0 / nx)
y = mx.arange(0, 1.0, 1.0 / ny)
dx = x[1] - x[0]
dy = y[1] - y[0]

# Function to calculate maximum stable time step for explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1 / dx**2 + 1 / dy**2))

# Material properties: density and specific heat capacity
rho = 1.0
cp = 1.0
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute the maximum stable time step and choose the smaller of the two
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)

# Initialize temperature field arrays on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Apply Dirichlet boundary conditions (fixed temperatures at boundaries)
T[:, 0] = 0.0  # Left boundary
T[:, -1] = 1.0  # Right boundary

# Precompute squared grid spacings for use in the convolution kernel
dxsq = dx**2
dysq = dy**2

# Define the convolution kernel for the finite difference method
# Kernel encodes the discretized Laplacian operator
kernel = mx.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]]) * (dt * k / dxsq)
kernel = kernel.reshape(1, 3, 3, 1)  # Reshape to match conv2d requirements
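# NOTE: a single isotropic kernel is only valid because dx == dy here
# (unit square with nx == ny); otherwise separate 1/dxsq and 1/dysq
# weights would be needed for the x- and y-direction entries.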

# Reshape T for convolution, adding batch and channel dimensions
T_reshaped = mx.expand_dims(T, (0, -1))  # New shape: (1, 1000, 1000, 1)

# Start timer for performance measurement
start_time = time.time()

# Main time-stepping loop
for step in range(max_steps):
    # Periodically evaluate the computation graph for efficiency
    if step % 20_000 == 0:
        mx.eval(T_reshaped)

    # Convolution operation to update the interior points
    T_reshaped = T_reshaped + mx.conv2d(T_reshaped, kernel, padding=(1, 1))
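    # Note: padding=(1, 1) zero-pads outside the domain, so the boundary rows
    # and columns produced by this convolution are not physical; they are
    # overwritten when the boundary conditions are re-applied below.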

    # Reapply Dirichlet boundary conditions after convolution
    T_reshaped[:, :, 0, :] = 0.0  # Left boundary
    T_reshaped[:, :, -1, :] = 1.0  # Right boundary

    # Reapply Neumann boundary conditions (zero-flux) at top and bottom
    T_reshaped[:, 0, :, :] = T_reshaped[:, 1, :, :]
    T_reshaped[:, -1, :, :] = T_reshaped[:, -2, :, :]

# End timer and calculate elapsed time
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

# Reshape T back to its original 2D form and evaluate
T = mx.squeeze(T_reshaped, axis=[0, -1])
mx.eval(T)

# Visualization of the temperature field
import matplotlib.pyplot as plt

plt.imshow(T, cmap="hot", interpolation="nearest")
plt.colorbar()  # Add color bar to indicate temperature scales
plt.show()
awni commented 11 months ago

It looks like we have some work to do for those shapes in our conv2d. I think your implementation has much more potential to be fast. If we can make the conv comparable to Torch's, it should really sing! In the meantime, bear with us... we're a small team with a lot to do. But things will only get much better from here!

I wrote a little benchmark and for the sizes in your problem we are > 10x slower than Torch's conv2d. CC @jagrit06 who will hopefully have some time to work on this in the near future.

import numpy as np
import torch.nn.functional
import mlx.core as mx
import time

### Time Torch
T = np.random.randn(1000, 1000).astype(np.float32)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32)

T = T.reshape(1, 1, 1000, 1000)
W = W.reshape(1, 1, 3, 3)

device = torch.device("mps")
T = torch.tensor(T).to(device)
W = torch.tensor(W).to(device)

for _ in range(5):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()

tic = time.time()
for _ in range(100):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()
toc = time.time()

print(f"Torch: {toc - tic}")

### Time MLX
T = np.random.randn(1000, 1000).astype(np.float32).reshape(1, 1000, 1000, 1)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32).reshape(1, 3, 3, 1)
T = mx.array(T)
W = mx.array(W)

for _ in range(5):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)

tic = time.time()
for _ in range(100):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)
toc = time.time()
print(f"MLX: {toc - tic}")
awni commented 11 months ago

it might be useful to have a section of examples on engineering/physics applications

That's a great idea! We could potentially put a physics/ directory in the MLX examples repo, or make something standalone. Since I don't know much about those use cases, if you are interested in this, maybe a good starting point is to make your own repo with some of these examples + comments / explanations, and we can go from there.

sck-at-ucy commented 11 months ago

  1. Great, it would be awesome to have a fast conv2d, but I do realize how full your hands are right now; it is actually amazing how fast things are progressing @jagrit06

  2. I forked the repo and will try to create a library of engineering/physics examples of increasing complexity; once I have sufficient material in place, you can review it and decide whether to merge or keep it separate.

sck-at-ucy commented 11 months ago

Forgive my ignorance, but I'm trying to understand the following: if you move the

torch.mps.synchronize() and mx.eval(T)

statements after the respective toc = time.time(), MLX is actually way faster (30X). Does this mean that MLX is faster at creating the compute graph, but MLX conv2d is slower when it is time to do the actual computation? @awni

import numpy as np
import torch.nn.functional
import mlx.core as mx
import time

### Time Torch
T = np.random.randn(1000, 1000).astype(np.float32)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32)

T = T.reshape(1, 1, 1000, 1000)
W = W.reshape(1, 1, 3, 3)

device = torch.device("mps")
T = torch.tensor(T).to(device)
W = torch.tensor(W).to(device)

for _ in range(5):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()

tic = time.time()
for _ in range(100):
    T = torch.nn.functional.conv2d(T, W, padding="same")
toc = time.time()
torch.mps.synchronize()
print(f"Torch: {toc - tic}")

### Time MLX
T = np.random.randn(1000, 1000).astype(np.float32).reshape(1, 1000, 1000, 1)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32).reshape(1, 3, 3, 1)
T = mx.array(T)
W = mx.array(W)

for _ in range(5):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)

tic = time.time()
for _ in range(100):
    T = mx.conv2d(T, W, padding=(1, 1))
toc = time.time()
mx.eval(T)
print(f"MLX: {toc - tic}")

Torch: 0.005361795425415039
MLX: 0.00016999244689941406

Process finished with exit code 0
awni commented 11 months ago

Does this mean that MLX is faster at creating the compute graph, but MLX conv2d is slower when it is time to do the actual computation?

Exactly right.

jagrit06 commented 11 months ago

To add to the conversation about conv2d: it's a process of creating and optimizing many specializations. Even currently, the shapes that our Winograd conv supports will run faster than the alternatives. It's a work in progress to get a wider range of shapes to good speeds, and it is no small priority for us!

sck-at-ucy commented 11 months ago

So great to hear and much appreciated!

sck-at-ucy commented 11 months ago

Correction: by the way, the same kind of benchmark for mx.matmul vs torch.dot shows that MLX is about 5X faster than Torch for the dot product of two vectors at array sizes of ~1_000, MLX is about 1.5X faster than Torch at array sizes of ~10_000, but Torch is about 5X faster than MLX at array sizes of ~100_000.
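For reference, a sketch of the kind of benchmark meant here, mirroring the warm-up and synchronize/eval pattern of awni's conv2d benchmark above (illustrative, assuming @ on mx arrays follows NumPy's 1-D dot semantics):

import time
import numpy as np
import torch
import mlx.core as mx

n = 10_000  # vary: 1_000, 10_000, 100_000

a_np = np.random.randn(n).astype(np.float32)
b_np = np.random.randn(n).astype(np.float32)

### Time Torch
device = torch.device("mps")
a_t, b_t = torch.tensor(a_np).to(device), torch.tensor(b_np).to(device)
for _ in range(5):  # warm-up
    out = torch.dot(a_t, b_t)
torch.mps.synchronize()
tic = time.time()
for _ in range(100):
    out = torch.dot(a_t, b_t)
torch.mps.synchronize()
print(f"Torch: {time.time() - tic}")

### Time MLX
a_m, b_m = mx.array(a_np), mx.array(b_np)
for _ in range(5):  # warm-up
    out = a_m @ b_m
mx.eval(out)
tic = time.time()
for _ in range(100):
    out = a_m @ b_m
mx.eval(out)
print(f"MLX: {time.time() - tic}")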

sck-at-ucy commented 11 months ago

Just adding an additional observation that perhaps points to how the graphs are evaluated. The two implementations below differ by just a superfluous set of parentheses, but the one with the extra parentheses seems to group operations differently, and it needs mx.eval only half as often as the implementation without the extra set of parentheses to avoid the segfault. (Presumably this is because a + b + c parses as (a + b) + c: the extra parentheses let the two Laplacian terms combine into their own subtree instead of extending the chain hanging off T_mid, roughly halving the per-step depth of the graph that the eval recursion has to walk.)

# Version with extra parentheses
    T = mx.pad( mx.pad(
        (T_mid + (
          dt_alpha_ov_dxsq * ( T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1] )
        + dt_alpha_ov_dysq * ( T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2] ))),
        ((0, 0), (0, 1)),1), ((1, 1), (1, 0)),0)

# Version without the superfluous parentheses
    T = mx.pad( mx.pad(
        (T_mid + 
          dt_alpha_ov_dxsq * ( T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1] )
        + dt_alpha_ov_dysq * ( T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2] )),
        ((0, 0), (0, 1)),1), ((1, 1), (1, 0)),0)