Open sck-at-ucy opened 11 months ago
It should never segfault, so that's not something you are doing sub-optimally with eval. It looks like a bug to me... but we'll have to investigate further.
Ok knowing that, I will try to do some debugging and see if I can provide some additional feedback.
So for some reason it segfaults if I don't put the mx.eval(T) at the beginning (e.g. before starting the loop). That definitely smells like a bug. But otherwise you don't need an eval until the end.
Also I made a couple of slight modifications, it should be noticeably faster (assuming I didn't mess anything up).
# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet PCs
import numpy as np
import time
import mlx.core as mx
# Grid size and material properties setup
nx, ny = 5000, 5000 # Set grid dimensions
k = 1.0 # Thermal conductivity
# Time-stepping parameters
desired_dt = 0.01 # Desired time step
max_steps = 10000 # Maximum number of time steps
# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0] # Grid spacing in x direction
dy = y[1] - y[0] # Grid spacing in y direction
# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))
# Material properties for stability calculation
rho = 1.0 # Density
cp = 1.0 # Specific heat capacity
alpha = k / (rho * cp) # Thermal diffusivity
# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt) # Use the smaller of the desired or maximum stable time step
# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)
# Applying Dirichlet boundary conditions
T[:, 0] = 0.0 # Set left boundary temperature
T[:, -1] = 1.0 # Set right boundary temperature
# Time-stepping loop for the heat equation
dysq = dy**2
dxsq = dx**2
mx.eval(T)
start_time = time.time() # Capture start time
for step in range(max_steps):
    # Update interior points using finite difference method
    # Pad the interior points for broadcasting
    T_mid = T_old[1:-1,1:-1]
    T = mx.pad(mx.pad((T_mid + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
        (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
        )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)
    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)
mx.eval(T)
end_time = time.time() # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")
It runs orders of magnitude faster (438 sec vs 2.37 sec for 100K steps)!!! And whatever the bug is, it's not surfacing any more. I'm thrilled with the speedup!
I was too quick to respond. The T_old temperature field (see the code below) needs to be updated in each time step; that's why I had the line T_old = mx.broadcast_to(T, shape=T.shape)
at the beginning of each time step, which I found was a bit faster than setting T_old = T
. Once I include this line again, the problem resurfaces and the speeds go back to more or less what I was getting before.
T_mid = T_old[1:-1,1:-1]
T = mx.pad(mx.pad((T_mid + dt * k * (
    (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
    (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)
Oops, sorry about that!
It's very odd that doing T_old = mx.broadcast_to(T, shape=T.shape)
is faster than T_old = T
. More to investigate here.
Following your remark, I tested again T_old = mx.broadcast_to(T,shape=T.shape)
and T_old=T
and they are about the same (within load noise). I must have been misled by some other change I made in parallel before. So, this does not need to be checked. Sorry for the wrong claim!
Ok, I did some basic debugging using the debugger in PyCharm (PyDev). I inserted a breakpoint at the iteration step at which I would normally need to execute mx.eval(T)
to avoid the segfault. Stepping through, all values in the arrays look normal and nothing weird happens. However, when I make the first step forward from the breakpoint, it takes a long time for variable values to load in the debugger. I can step through the code and complete execution correctly. I am not sure if this adds any useful hints. Suspicion: am I effectively using an eval
by loading the values in the debugger, and is that why the code completes correctly without needing the explicit mx.eval(T)
? To my non-expert eyes this looks like a low-level error in MLX, as there is nothing funny happening in the actual array values. Probably something with the dynamic graphs and memory?
I'm not able to reproduce the segfault. Does it segfault for you using the settings you have in the example above?
Also what hardware / OS are you using?
Yes, but below I also provide the latest version (I have added your suggestions), which still segfaults if I don't mx.eval(T)
or mx.eval(T_old)
at step % 150000 == 0
. The exact frequency of needing to mx.eval depends on the problem size (nx, ny).
Model Name: Mac Studio
Model Identifier: Mac14,14
Model Number: Z180000DEB/A
Chip: Apple M2 Ultra
Total Number of Cores: 24 (16 performance and 8 efficiency)
Memory: 128 GB
System Firmware Version: 10151.41.12
OS Loader Version: 10151.41.12
macOS Sonoma Version 14.1.2

Python 3.11.6
numpy 1.26.2
mlx 0.0.4
Conda 23.11.0
libblas 3.9.0 20_osxarm64_accelerate conda-forge
libcblas 3.9.0 20_osxarm64_accelerate conda-forge
liblapack 3.9.0 20_osxarm64_accelerate conda-forge
Elapsed time: 2.99 seconds
Process finished with exit code 139 (interrupted by signal 11:SIGSEGV)
# Solving the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet PCs
import numpy as np
import time
import mlx.core as mx
# Grid size and material properties setup
nx, ny = 5000, 5000 # Set grid dimensions
k = 1.0 # Thermal conductivity
# Time-stepping parameters
desired_dt = 0.01 # Desired time step
max_steps = 50000 # Maximum number of time steps
# Creating a linearly spaced grid
x = mx.array(np.linspace(0,1,nx))
y = mx.array(np.linspace(0,1,ny))
dx = x[1] - x[0] # Grid spacing in x direction
dy = y[1] - y[0] # Grid spacing in y direction
# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1/dx**2 + 1/dy**2))
# Material properties for stability calculation
rho = 1.0 # Density
cp = 1.0 # Specific heat capacity
alpha = k / (rho * cp) # Thermal diffusivity
# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt) # Use the smaller of the desired or maximum stable time step
# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)
# Applying Dirichlet boundary conditions
T[:, 0] = 0.0 # Set left boundary temperature
T[:, -1] = 1.0 # Set right boundary temperature
# Time-stepping loop for the heat equation
dysq = dy**2
dxsq = dx**2
mx.eval(T)
mx.eval(T_old)
start_time = time.time() # Capture start time
for step in range(max_steps):
    #T_old = mx.broadcast_to(T, shape=T.shape)
    T_old = T
    # Update interior points using finite difference method
    # Pad the interior points for broadcasting
    #### Segfault Control ####
    # if step % 15000 == 0: mx.eval(T_old)  #(When commented out it should cause a segfault)
    T_mid = T_old[1:-1,1:-1]
    T = mx.pad(mx.pad((T_mid + dt * k * (
        (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
        (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
        )), ((0,0),(0,1)),1),((0,0),(1,0)), 0)
    # Update Neumann boundaries (zero-flux) at top and bottom
    T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)
end_time = time.time() # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")
mx.eval(T)
# Visualizing the temperature field using matplotlib
import matplotlib.pyplot as plt
plt.imshow(T, cmap='hot', interpolation='nearest')
plt.colorbar() # Add a color bar to indicate temperature scales
plt.show()
One more observation that might provide a clue. These two slightly different implementations (with the problem size nx, ny
kept the same between them) require different frequencies of mx.eval(T_old) to avoid a segfault. The difference between the two is that in the first case mx.pad
is used to add the top and bottom rows and then an assignment is made for the Neumann BCs, while in the second, the rows are added and assigned values via mx.concatenate.
if step % 21000 == 0:
    mx.eval(T_old)

T_mid = T_old[1:-1, 1:-1]
T = mx.pad(mx.pad((T_mid + dt * k * (
    (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
    (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((1, 1), (0, 1)), 1), ((0, 0), (1, 0)), 0)
# Update Neumann boundaries (zero-flux) at top and bottom
T[0, :] = T[1, :]
T[-1, :] = T[-2, :]
if step % 16000 == 0:
    mx.eval(T_old)

T_mid = T_old[1:-1, 1:-1]
T = mx.pad(mx.pad((T_mid + dt * k * (
    (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq +
    (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
    )), ((0, 0), (0, 1)), 1), ((0, 0), (1, 0)), 0)
# Update Neumann boundaries (zero-flux) at top and bottom
T = mx.concatenate([mx.expand_dims(T[0, :], (0)), T, mx.expand_dims(T[-1, :], (-0))], axis=0)
I am able to reproduce the segfault on my machine.
Current (and likely) hypothesis is that the segfault is a result of a stack overflow during the eval
when we recurse on the inputs to build the compute graph. It definitely makes sense that it depends on the size of the graph.
For now, I would recommend just doing a fixed amount of compute per eval, so adding something like:
if step % 1000 == 0:
    mx.eval(T_old)
I am not sure if we should fix this... we could try using an iterative graph construction, but at the same time I don't recommend letting graphs get so big. It's usually a sign that eval
should be used more frequently and/or the number of operations should be reduced.
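(For context, a minimal, self-contained sketch of where such a periodic eval would sit in the time-stepping loop; the update and the 1000-step interval are placeholders to be tuned per problem size, not the actual solver code:)

```python
import mlx.core as mx

T = mx.zeros([1_000, 1_000])
max_steps = 50_000

for step in range(max_steps):
    # Stand-in for the real finite-difference update of T (kept lazy by MLX).
    T = T + 1e-6

    # Bound the lazy graph: force evaluation every fixed number of steps
    # rather than letting all max_steps of work accumulate before one eval.
    if step % 1_000 == 0:
        mx.eval(T)

mx.eval(T)  # materialize the final result
```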
Another comment:
You should see if it's possible to replace the whole inner computation with something like a convolution or matmul with the appropriate kernel. It would dramatically reduce the number of operations which in this case would speed things up substantially. I think the whole update from T_old
to T
is linear so it should be very doable with the right operation.
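For what it's worth, here is a minimal sketch of the kind of replacement being suggested, assuming dx == dy so the 5-point Laplacian collapses into a single scaled kernel. The variable names and step size below are illustrative placeholders rather than values from the code above; the array layouts follow mx.conv2d's NHWC convention:

```python
import mlx.core as mx

nx = ny = 1_000
dx = 1.0 / nx
dt = 0.25 * dx * dx          # placeholder stable time step (alpha = k = 1)
k = 1.0
T = mx.zeros([nx, ny])

# 5-point Laplacian stencil folded into one kernel, valid when dx == dy.
# mx.conv2d expects NHWC input and (out_channels, kH, kW, in_channels) weights.
lap = mx.array([[0.0, 1.0, 0.0],
                [1.0, -4.0, 1.0],
                [0.0, 1.0, 0.0]]) * (dt * k / dx**2)
lap = lap.reshape(1, 3, 3, 1)

Tc = mx.expand_dims(T, (0, -1))                    # (1, nx, ny, 1)
Tc = Tc + mx.conv2d(Tc, lap, padding=(1, 1))       # one explicit Euler step
# Dirichlet (left/right) and Neumann (top/bottom) conditions would still be
# reapplied on Tc after each step, as in the finite-difference loop above.
mx.eval(Tc)
```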
Ok, at least I now have some level of understanding as to what is going on.
The idea to use a convolution kernel is a great one; I will try it because it connects to the next project I want to tackle with MLX for my computational engineering course. I know how I can do that with Torch; I'll see if I can make it work with MLX and will come back for help if I run into trouble.
BTW, if you like the idea, it might be useful to have a section of examples on engineering/physics applications (not necessarily falling directly under the ML/AI area) to provide an easy entry point for people wanting to leverage MLX for science. I would be happy to contribute in this direction.
Here's the latest Blackened version of the code before I implement the suggested change:
2D_Heat_Equation_MLX_A.py
# Solve the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs
# with the Apple MLX framework.
# Stavros Kassinos
# Example for course MME 419 @ UCY, December 2023
import time
import mlx.core as mx

# Select Execution Device
mx.set_default_device(mx.gpu)

# Grid size and material properties setup
nx, ny = 1_000, 1_000  # Set grid dimensions
k = 1.0  # Thermal conductivity

# Time-stepping parameters
desired_dt = 0.01  # Desired time step
max_steps = 50_000  # Maximum number of time steps

# Creating a linearly spaced grid
x = mx.arange(0, 1.0, 1.0 / nx)
y = mx.arange(0, 1.0, 1.0 / ny)

dx = x[1] - x[0]  # Grid spacing in x direction
dy = y[1] - y[0]  # Grid spacing in y direction


# Function to calculate the maximum stable time step for the explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1 / dx**2 + 1 / dy**2))


# Material properties for stability calculation
rho = 1.0  # Density
cp = 1.0  # Specific heat capacity
alpha = k / (rho * cp)  # Thermal diffusivity

# Compute maximum stable time step
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(
    dt_max, desired_dt
)  # Use the smaller of the desired or maximum stable time step

# Initializing the temperature field on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)

# Applying Dirichlet boundary conditions
T[:, 0] = 0.0  # Set left boundary temperature
T[:, -1] = 1.0  # Set right boundary temperature

# Time-stepping loop for the heat equation

dysq = dy**2
dxsq = dx**2

start_time = time.time()  # Capture start time
for step in range(max_steps):
    T_old = T  # Update previous step T-field

    # Eval to keep execution graph efficient
    if step % 20_000 == 0:
        mx.eval(T)

    # Update interior points using finite difference method
    # Pad the interior points for broadcasting to T.shape
    # Enforce Dirichlet BCs

    T_mid = T_old[1:-1, 1:-1]

    T = mx.pad(
        mx.pad(
            (
                T_mid
                + dt
                * k
                * (
                    (T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1]) / dxsq
                    + (T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2]) / dysq
                )
            ),
            ((1, 1), (0, 1)),
            1,
        ),
        ((0, 0), (1, 0)),
        0,
    )

    # Update Neumann boundaries (zero-flux) at top and bottom
    T[0, :] = T[1, :]
    T[-1, :] = T[-2, :]

end_time = time.time()  # Capture end time
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")

mx.eval(T)

# Visualizing the temperature field using matplotlib
import matplotlib.pyplot as plt

plt.imshow(T, cmap="hot", interpolation="nearest")
plt.colorbar()  # Add a color bar to indicate temperature scales
plt.show()
I converted the code to use mx.conv2d
and to my surprise it is taking about 3 times longer than the version shared just above in my last comment. Am I doing something wrong? I was confident that I would be seeing time savings...
import time
import mlx.core as mx
# Set the GPU as the default execution device for MLX operations
mx.set_default_device(mx.gpu)
# Grid size setup (1000x1000) and thermal conductivity definition
nx, ny = 1_000, 1_000
k = 1.0
# Time-stepping parameters: desired time step and maximum number of steps
desired_dt = 0.01
max_steps = 50_000
# Create linearly spaced grids for x and y, and compute grid spacings
x = mx.arange(0, 1.0, 1.0 / nx)
y = mx.arange(0, 1.0, 1.0 / ny)
dx = x[1] - x[0]
dy = y[1] - y[0]
# Function to calculate maximum stable time step for explicit Euler method
def calculate_max_stable_dt(alpha, dx, dy):
    return (1 / (2 * alpha)) * (1 / (1 / dx**2 + 1 / dy**2))
# Material properties: density and specific heat capacity
rho = 1.0
cp = 1.0
alpha = k / (rho * cp) # Thermal diffusivity
# Compute the maximum stable time step and choose the smaller of the two
dt_max = calculate_max_stable_dt(alpha, dx, dy)
dt = min(dt_max, desired_dt)
# Initialize temperature field arrays on the GPU
T = mx.zeros([nx, ny])
T_old = mx.zeros_like(T)
# Apply Dirichlet boundary conditions (fixed temperatures at boundaries)
T[:, 0] = 0.0 # Left boundary
T[:, -1] = 1.0 # Right boundary
# Precompute squared grid spacings for use in the convolution kernel
dxsq = dx**2
dysq = dy**2
# Define the convolution kernel for the finite difference method
# Kernel encodes the discretized Laplacian operator
kernel = mx.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]]) * (dt * k / dxsq)
kernel = kernel.reshape(1, 3, 3, 1) # Reshape to match conv2d requirements
# Reshape T for convolution, adding batch and channel dimensions
T_reshaped = mx.expand_dims(T, (0, -1)) # New shape: (1, 1000, 1000, 1)
# Start timer for performance measurement
start_time = time.time()
# Main time-stepping loop
for step in range(max_steps):
    # Periodically evaluate the computation graph for efficiency
    if step % 20_000 == 0:
        mx.eval(T_reshaped)
    # Convolution operation to update the interior points
    T_reshaped = T_reshaped + mx.conv2d(T_reshaped, kernel, padding=(1, 1))
    # Reapply Dirichlet boundary conditions after convolution
    T_reshaped[:, :, 0, :] = 0.0  # Left boundary
    T_reshaped[:, :, -1, :] = 1.0  # Right boundary
    # Reapply Neumann boundary conditions (zero-flux) at top and bottom
    T_reshaped[:, 0, :, :] = T_reshaped[:, 1, :, :]
    T_reshaped[:, -1, :, :] = T_reshaped[:, -2, :, :]
# End timer and calculate elapsed time
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time:.2f} seconds")
# Reshape T back to its original 2D form and evaluate
T = mx.squeeze(T_reshaped, axis=[0, -1])
mx.eval(T)
# Visualization of the temperature field
import matplotlib.pyplot as plt
plt.imshow(T, cmap="hot", interpolation="nearest")
plt.colorbar() # Add color bar to indicate temperature scales
plt.show()
It looks like we have some work to do for those shapes in our conv2D
. I think your implementation has much more potential to be fast. If we can make the conv comparable to torch it should really sing! In the meantime, bear with us... we're a small team with a lot to do. But things will only get much better from here!
I wrote a little benchmark and for the sizes in your problem we are > 10x slower than Torch's conv2d. CC @jagrit06 who will hopefully have some time to work on this in the near future.
import numpy as np
import torch.nn.functional
import mlx.core as mx
import time
### Time Torch
T = np.random.randn(1000, 1000).astype(np.float32)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32)
T = T.reshape(1, 1, 1000, 1000)
W = W.reshape(1, 1, 3, 3)
device = torch.device("mps")
T = torch.tensor(T).to(device)
W = torch.tensor(W).to(device)
for _ in range(5):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()
tic = time.time()
for _ in range(100):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()
toc = time.time()
print(f"Torch: {toc - tic}")
### Time MLX
T = np.random.randn(1000, 1000).astype(np.float32).reshape(1, 1000, 1000, 1)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32).reshape(1, 3, 3, 1)
T = mx.array(T)
W = mx.array(W)
for _ in range(5):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)
tic = time.time()
for _ in range(100):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)
toc = time.time()
print(f"MLX: {toc - tic}")
it might be useful to have a section of examples on engineering/physics applications
That's a great idea! We could potentially put a physics/
directory in the MLX examples repo, or make something standalone. Since I don't know much about those use cases, if you are interested in this, maybe a good starting point is to make your own repo with some of these examples + comments / explanations, and we can go from there.
Great, it would be awesome to have a fast conv2d, but I do realize how full your hands are right now; it is actually amazing how fast things are progressing @jagrit06
I forked the repo and will try to create a library of engineering/physics examples of increasing complexity and once I have sufficient material in place you could review and decide whether to merge or keep it separate.
Forgive my ignorance, but I'm trying to understand the following: if you move the torch.mps.synchronize()
and mx.eval(T)
statements after the respective toc = time.time()
, MLX is actually way faster (30X). Does this mean that MLX is faster in creating the comp graphs but MLX conv2D is slower when it is time to do the actual computation? @awni
import numpy as np
import torch.nn.functional
import mlx.core as mx
import time
### Time Torch
T = np.random.randn(1000, 1000).astype(np.float32)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32)
T = T.reshape(1, 1, 1000, 1000)
W = W.reshape(1, 1, 3, 3)
device = torch.device("mps")
T = torch.tensor(T).to(device)
W = torch.tensor(W).to(device)
for _ in range(5):
    T = torch.nn.functional.conv2d(T, W, padding="same")
torch.mps.synchronize()
tic = time.time()
for _ in range(100):
    T = torch.nn.functional.conv2d(T, W, padding="same")
toc = time.time()
torch.mps.synchronize()
print(f"Torch: {toc - tic}")
### Time MLX
T = np.random.randn(1000, 1000).astype(np.float32).reshape(1, 1000, 1000, 1)
W = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], np.float32).reshape(1, 3, 3, 1)
T = mx.array(T)
W = mx.array(W)
for _ in range(5):
    T = mx.conv2d(T, W, padding=(1, 1))
mx.eval(T)
tic = time.time()
for _ in range(100):
    T = mx.conv2d(T, W, padding=(1, 1))
toc = time.time()
mx.eval(T)
print(f"MLX: {toc - tic}")
Torch: 0.005361795425415039
MLX: 0.00016999244689941406
Process finished with exit code 0
Does this mean that MLX is faster in creating the comp graphs but MLX conv2D is slower when it is time to do the actual computation?
Exactly right.
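In other words, with lazy evaluation the loop above only records the conv operations; the actual kernels run when mx.eval (or torch.mps.synchronize, on the MPS side) is reached. To time the computation itself, the evaluation has to sit inside the timed region. A minimal sketch of the two patterns, using a toy matmul rather than the conv from the benchmark:

```python
import time
import mlx.core as mx

x = mx.random.normal((1000, 1000))

# Pattern 1: eval outside the timed region -> measures only graph construction.
tic = time.time()
y = x @ x
toc = time.time()          # y has not been computed yet; this time is tiny
mx.eval(y)                 # the real work happens here, outside the timer

# Pattern 2: eval inside the timed region -> measures the actual compute.
tic = time.time()
y = x @ x
mx.eval(y)                 # force the computation before stopping the clock
toc = time.time()
print(f"compute time: {toc - tic:.4f} s")
```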
To add to the conversation about conv2D - it's a process of creating and optimizing many specializations. Even currently, the shapes that our Winograd conv supports will run faster than alternatives. It's a work in progress to get a wider range of shapes to good speeds, and it is no small priority for us!
So great to hear and much appreciated!
Correction: by the way, the same kind of benchmark for mx.matmul
vs torch.dot
shows that MLX is about 5X faster than Torch for the dot product of two vectors at array sizes of ~1_000, MLX is about 1.5X faster than Torch at array sizes of ~10_000, but Torch is about 5X faster than MLX at array sizes of ~100_000.
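For reference, a sketch of how such a vector-product comparison might be set up, reusing the warmup-and-synchronize pattern from the conv benchmark above. The size n is a placeholder, and I'm assuming mx.matmul follows NumPy's 1-D promotion rules so that two vectors give an inner product; numbers will of course vary by machine:

```python
import time
import numpy as np
import torch
import mlx.core as mx

n = 10_000
a_np = np.random.randn(n).astype(np.float32)
b_np = np.random.randn(n).astype(np.float32)

# Torch (MPS)
device = torch.device("mps")
a_t, b_t = torch.tensor(a_np).to(device), torch.tensor(b_np).to(device)
for _ in range(5):                       # warmup
    out = torch.dot(a_t, b_t)
torch.mps.synchronize()
tic = time.time()
for _ in range(100):
    out = torch.dot(a_t, b_t)
torch.mps.synchronize()
print(f"Torch: {time.time() - tic}")

# MLX
a_m, b_m = mx.array(a_np), mx.array(b_np)
for _ in range(5):                       # warmup
    mx.eval(mx.matmul(a_m, b_m))
tic = time.time()
for _ in range(100):
    # Each product is an independent lazy graph, so it must be evaluated
    # inside the loop for the timing to include the actual compute.
    mx.eval(mx.matmul(a_m, b_m))
print(f"MLX: {time.time() - tic}")
```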
Just adding an additional observation that perhaps points to how the graphs are evaluated. The two implementations below differ by just a superfluous set of parentheses, but the one with the extra parentheses seems to group operations differently and can get by with mx.eval at half the frequency of the implementation without the extra set of parentheses.
#Version with extra parentheses
T = mx.pad( mx.pad(
(T_mid + (
dt_alpha_ov_dxsq * ( T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1] )
+ dt_alpha_ov_dysq * ( T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2] ))),
((0, 0), (0, 1)),1), ((1, 1), (1, 0)),0)
# Version without the superfluous parentheses
T = mx.pad( mx.pad(
(T_mid +
dt_alpha_ov_dxsq * ( T_old[2:, 1:-1] - 2 * T_mid + T_old[:-2, 1:-1] )
+ dt_alpha_ov_dysq * ( T_old[1:-1, 2:] - 2 * T_mid + T_old[1:-1, :-2] )),
((0, 0), (0, 1)),1), ((1, 1), (1, 0)),0)
I have implemented a simple solution of the 2D Heat Conduction Equation with 2 Neumann and 2 Dirichlet BCs. I have the code implemented both using PyTorch and the MLX framework and I am testing the relative performance on an M2 Ultra with 128GB memory.
The MLX code is included below. So far, performance in various tests (on the same machine) shows the MLX version to be somewhere between 2X and 10X faster depending on the problem size.
However, I have an issue that I need to understand. Depending on the problem size, I need to include the line
if step % 15000 == 0: mx.eval(T)
to avoid a segmentation fault. I imagine this has to do with the lazy evaluation and arrays being in a buffer? My issue is that currently I figure out empirically, each time, how often I need to mx.eval
. Is there some programmatic and more elegant way to automatically issue the mx.eval
at the right frequency based on the problem size? Here is the complete code below. Thank you for all your help @awni!