cowolff opened this issue 4 months ago (status: Open)
I am experiencing something similar. I have implemented a model that tries to backtrack the edge boundary conditions of a 2D heat conduction problem, given as input only the final temperature distribution in the interior of the plate (not the boundary conditions). The Torch version is indeed about 10 times faster, although I haven't traced where the bottleneck is. In both cases the optimizer is Adam with the same learning rate.
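For context, the forward model that gets differentiated here is explicit finite-difference time stepping of the 2D heat equation. The thread does not show `update_fd_method`, so the NumPy sketch below is a hypothetical stand-in, not the author's code. It assumes the common five-point (FTCS) stencil with `dt_alpha_ov_dxsq = alpha*dt/dx**2` paired with axis 0 and `dt_alpha_ov_dysq = alpha*dt/dy**2` paired with axis 1. Only interior nodes are updated; the boundary rows and columns are left for `apply_bc` to reset, as in the loop of `compute_field_loss`.

```python
import numpy as np


def update_fd_method(T, dt_alpha_ov_dxsq, dt_alpha_ov_dysq):
    """One explicit (FTCS) step of the 2D heat equation on interior nodes.

    Hypothetical stand-in: assumes axis 0 pairs with dx and axis 1 with dy.
    Boundary values are copied unchanged and must be reset by apply_bc.
    """
    T_new = T.copy()
    T_new[1:-1, 1:-1] = (
        T[1:-1, 1:-1]
        # second difference along axis 0
        + dt_alpha_ov_dxsq * (T[2:, 1:-1] - 2 * T[1:-1, 1:-1] + T[:-2, 1:-1])
        # second difference along axis 1
        + dt_alpha_ov_dysq * (T[1:-1, 2:] - 2 * T[1:-1, 1:-1] + T[1:-1, :-2])
    )
    return T_new
```

Being explicit, this scheme is only stable when `dt_alpha_ov_dxsq + dt_alpha_ov_dysq <= 0.5`, which is one reason these runs need thousands of small steps (`max_steps`).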
MLX Version
import mlx.core as mx
import mlx.nn as nn

class BCModel(nn.Module):
    def __init__(self, ny, initial_left_bc, initial_right_bc):
        super(BCModel, self).__init__()
        # Initialize boundary conditions as scalar values; mx.array attributes
        # of an nn.Module are picked up as trainable parameters
        self.left_bc = mx.array([initial_left_bc])
        self.right_bc = mx.array([initial_right_bc])

    def apply_bc(self, T):
        # Apply scalar boundary conditions uniformly across the boundary.
        # T[:, 0] has T.shape[0] entries, so size the vector from T instead of a global ny
        n_rows = T.shape[0]
        T[:, 0] = self.left_bc * mx.ones(n_rows)
        T[:, -1] = self.right_bc * mx.ones(n_rows)
        T[0, :] = T[1, :]    # Neumann BC
        T[-1, :] = T[-2, :]  # Neumann BC
        return T
Torch Version
import torch
import torch.nn as nn

class BCModel(nn.Module):
    def __init__(self, ny, initial_left_bc, initial_right_bc):
        super(BCModel, self).__init__()
        # Define left_bc and right_bc as nn.Parameters (scalars)
        self.left_bc = nn.Parameter(torch.tensor(initial_left_bc))
        self.right_bc = nn.Parameter(torch.tensor(initial_right_bc))

    def apply_bc(self, T, ny):
        # expand() broadcasts the 0-d parameter to a length-ny view without copying
        T[:, 0] = self.left_bc.expand(ny)
        T[:, -1] = self.right_bc.expand(ny)
        T[0, :] = T[1, :]    # Neumann BC
        T[-1, :] = T[-2, :]  # Neumann BC
        return T
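Both versions write the Dirichlet columns first and then copy the first/last interior rows outward, so the four corners end up with the Dirichlet values. A NumPy mirror of the same logic (no autograd, illustration only; `apply_bc` here is a plain function, not either model's method) makes this visible on a small non-square grid, and shows that assigning the scalar directly lets broadcasting handle the column length:

```python
import numpy as np


def apply_bc(T, left_bc, right_bc):
    """NumPy mirror of BCModel.apply_bc: Dirichlet columns, Neumann rows."""
    T = T.copy()          # the models write into T; copying keeps this demo side-effect free
    T[:, 0] = left_bc     # Dirichlet: the scalar broadcasts down every row
    T[:, -1] = right_bc
    T[0, :] = T[1, :]     # Neumann (zero gradient): first row copies first interior row
    T[-1, :] = T[-2, :]   # Neumann: last row copies last interior row
    return T


# 4 rows x 3 columns; the corners inherit 10.0 and 20.0 because the row copy happens last
out = apply_bc(np.arange(12.0).reshape(4, 3), left_bc=10.0, right_bc=20.0)
```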
It would be great, for both cases, if you could share the code along with how you timed it, so we can be sure we are all measuring the same thing!
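On the measurement question: MLX evaluates lazily and PyTorch on Apple GPUs dispatches kernels asynchronously, so a timer can stop before any real work has run. A minimal, framework-agnostic helper (hypothetical, not taken from either attached script) takes a `sync` callback that must block until the step's work is finished, and discards a few warm-up calls:

```python
import statistics
import time


def benchmark(step_fn, sync=lambda out: None, warmup=3, repeats=10):
    """Return the median wall-clock seconds per call of step_fn.

    sync(out) must block until the work behind `out` is done; otherwise
    lazy or asynchronous backends are timed before they compute anything.
    """
    for _ in range(warmup):  # exclude one-off compilation and allocation costs
        sync(step_fn())
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(step_fn())
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

With MLX, `sync=lambda out: mx.eval(out)` forces the returned arrays; with PyTorch on an Apple GPU, `sync=lambda out: torch.mps.synchronize()` waits for queued kernels. Using the median rather than the mean keeps one slow outlier from dominating.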
Happy to share the code, needs a little cleaning first 😅. Coming soon ...
BC_Backtrack_MLX.py.zip BC_Bactrack_Torch.py.zip
I have attached the two versions of the code. Some observations follow:
For the same number of epochs (e.g. 500), Torch achieves a better loss reduction and a better estimate of the BCs in about 450 s, while MLX (as I have implemented it 😬) takes about 4x as long, and both its loss reduction and its BC estimate are worse.
In compute_field_loss(), I could not use max_steps beyond roughly 5000, even when I tried inserting some mx.eval calls. With Torch there was no issue, whatever the value of max_steps.
There is a good chance I am making some obvious mistakes in my MLX implementation, as I am still learning it. For example, I am not sure whether I needed to do the equivalent of loss.backward()  # Compute gradients
separately. I would be thankful for any comments that help me improve the implementation.
OK, this updated version of the MLX implementation solves the problem with max_steps: it can now handle an arbitrary number of iterations. However, I only "kind of" understand why my fix worked.
The slowness remains, although I haven't tried larger arrays (nx, ny) yet. In my experience, MLX generally starts to have an advantage on larger problems.
@sck-at-ucy could you say a bit more about what it was doing when it didn't work (i.e. going beyond 5k)? I notice you added:
for step in range(max_steps):
    if step % 5000 == 0: mx.eval(T)
Presumably it was crashing? On my machine both of your examples run and seem to produce the same result (I have max_steps at the default of 5000).
With max_steps at the default of 5000, both versions of the MLX code work, and this seems to be the upper limit for which things work. If I want more iterations, the first version, where the iterations run inside the loss function, runs into the problem and stops working. In particular, adding a periodic mx.eval(T_model) inside the loop, as shown below, does not help:
import mlx.core as mx

def compute_field_loss(model, T_reference, nx, ny, dx, dy, dt_alpha_ov_dxsq, dt_alpha_ov_dysq, max_steps, edge_weight):
    T_model = mx.mean(T_reference) * mx.ones((nx, ny))
    T_model = model.apply_bc(T_model)
    #if epoch % 10 == 0: print(f'epoch: {epoch} and {T_model}')
    for step in range(max_steps):
        # Test the loop counter (max_steps never changes) so the eval fires every 1000 steps
        if step % 1000 == 0: mx.eval(T_model)
        T_model = update_fd_method(T_model, dt_alpha_ov_dxsq, dt_alpha_ov_dysq)
        T_model = model.apply_bc(T_model)

    diff = T_model - T_reference
    temp_range = mx.max(T_reference) - mx.min(T_reference)
    temp_range = mx.maximum(temp_range, 1e-6)  # Avoid division by zero without leaving MLX
    normalized_diff = diff / temp_range

    # Apply different weights to edge-adjacent and central nodes
    edge_mask = mx.ones_like(T_model)
    edge_mask[:, 0:4] = 0  # Exclude leftmost Dirichlet boundary (first 4 columns)
    edge_mask[:, -4:] = 0  # Exclude rightmost Dirichlet boundary (last 4 columns)
    # Columns 1 and -2 were already zeroed above, so these two weights have no effect
    edge_mask[:, 1] *= edge_weight  # Left edge-adjacent nodes
    edge_mask[:, -2] *= edge_weight  # Right edge-adjacent nodes
    weighted_diff = edge_mask * normalized_diff
    loss = mx.sum((weighted_diff[:, 1:-1]) ** 2)
    return loss
The code does not crash and produces no error messages, but execution goes to zombie land: it never completes. In the GPU history, utilization spikes at first, then drops and seems to wander around randomly. Presumably execution has stalled?
In the refactored version, the finite-difference iterations have been moved out of the loss function into a separate method, and there adding if step % 5000 == 0: mx.eval(T) works: I can use any value of max_steps. The difference in the GPU history is clear, with good utilization now.
I have encountered this before: the code simply zombies out rather than crashing hard. It would be nice to understand this behavior better so that, even if it is not fixable, I can write code in a way that does not trigger it. I had a hunch that moving the iterations out of the loss function would help, but since I am not sure why it worked, I might trigger it again next time.
I am attaching the latest version of the code, which includes some corrections to the plot handling and the movie generation at the end.
This version saves frames inside the loss function; that needs to be commented out for speed comparisons.
My naive conclusion: including the finite-difference iterations directly in the loss function makes the computation graph more complex to track. Still, the zombie behavior is concerning.
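A toy pure-Python model (not how MLX is implemented internally) of why a periodic eval helps in a forward-only loop: every lazy op adds a pending node, so without evaluation the graph grows linearly with the number of steps, while evaluating every k steps caps the pending chain at k. It deliberately does not model a loop inside a function being differentiated, where reverse-mode AD still needs the intermediate results; whether that explains the hang reported above is not established in this thread.

```python
class Lazy:
    """Toy lazy value: stores how to compute itself until evaluated."""

    def __init__(self, value=None, fn=None, inputs=()):
        self.value, self.fn, self.inputs = value, fn, inputs

    def scale(self, k):
        # Recording the op is cheap; nothing is computed yet
        return Lazy(fn=lambda a: a * k, inputs=(self,))


def evaluate(x):
    """Compute x, cache the result and drop the pending graph behind it."""
    if x.value is None:
        x.value = x.fn(*[evaluate(i) for i in x.inputs])
        x.fn, x.inputs = None, ()
    return x.value


def depth(x):
    """Length of the longest chain of pending ops behind x."""
    return 0 if not x.inputs else 1 + max(depth(i) for i in x.inputs)


def run(steps, eval_every=None):
    T = Lazy(value=1.0)
    for step in range(1, steps + 1):
        T = T.scale(0.5)
        if eval_every and step % eval_every == 0:
            evaluate(T)  # plays the role of a periodic mx.eval(T)
    return T
```

Here `depth(run(100))` is 100, whereas `depth(run(105, eval_every=10))` is only 5: the graph behind the current value never exceeds the evaluation interval.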
The slowness relative to Torch at this problem size is still an open question. I will try larger problems to see how the comparison goes.
As expected, increasing the problem size (nx x ny) to 256 x 256 (from the previous 64 x 64) shows that, for the same number of epochs and similar loss, MLX takes about 251 s versus about 377 s for Torch. This matches my experience in other cases: for small problems Torch has an advantage, but that quickly reverses as the problem size increases.
We'd like to get better at small sizes too. Thanks for the detailed benchmark! Perf is a high priority right now, and more benchmarks to examine are very helpful! I'll keep you posted on any findings with the one you submitted. The hang is very strange and should not happen.
I would be happy to run additional benchmarks with this code if that would be useful. I am still exploring; if I notice any worthwhile insights, I will share them.
I have now also published my code on GitHub, with a requirements.txt so you can reproduce my results (https://github.com/cowolff/ppo_mlx). As mentioned earlier, line 245 in ppo_mlx.py takes much longer than its PyTorch equivalent:
loss, grads = loss_and_grad_fn(agent, pg_loss, ent_coef, entropy_loss, v_loss, vf_coef)
Please also note that my code still has a few other issues, so my PPO implementation does not actually learn anything yet. That doesn't change the fact that calculating the loss and gradients takes very long, though. Thanks for your help!
I am implementing a version of PPO in MLX and wanted to benchmark it against my PyTorch implementation. Sadly, the performance (samples per second) was really quite bad, so I benchmarked all the different parts. Turns out that the inference performance during sampling is not the issue (quite the contrary) and also my loss calculation is really fast compared to PyTorch. It just takes veeery long to compute the gradients compared to PyTorch.
My agent model:
My loss function:
This part takes forever (more than 10 times longer than the PyTorch equivalent):
Is there a specific reason why this takes so much longer than the following PyTorch code?
I am using a MacBook Pro (M1 Pro) with 16 GB RAM and macOS Sonoma 14.2.