ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] value_and_grad is really slow #451

Open cowolff opened 4 months ago

cowolff commented 4 months ago

I am implementing a version of PPO in MLX and wanted to benchmark it against my PyTorch implementation. Sadly, the performance (samples per second) was quite bad, so I benchmarked all the different parts. It turns out that inference during sampling is not the issue (quite the contrary), and my loss calculation is also really fast compared to PyTorch. Computing the gradients, however, takes much longer than in PyTorch.

My agent model:

import mlx.nn as nn
import numpy as np

# softmax, sample_from_categorical, log_prob and calc_entropy are helper
# functions defined elsewhere in the PPO script.
class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        # Flattened observation size as a plain Python int for nn.Linear
        obs_dim = int(np.array(envs.single_observation_space.shape).prod())
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, envs.single_action_space.n),
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x)
        probs = softmax(logits)
        if action is None:
            # Sample a new action during rollouts
            action = sample_from_categorical(probs)

        return action, log_prob(probs, action), calc_entropy(probs), self.critic(x)

My loss function:

def loss_function(model, pg_loss, ent_coef, entropy_loss, v_loss, vf_coef):
    loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef
    return loss
loss_and_grad_fn = nn.value_and_grad(agent, loss_function)

This part takes forever (more than 10 times longer than the PyTorch equivalent):

loss, grads = loss_and_grad_fn(agent, pg_loss, ent_coef, entropy_loss, v_loss, vf_coef)

Is there a specific reason why this takes so much longer than the following PyTorch code?

optimizer.zero_grad()
loss.backward()
optimizer.step()

I am using a MacBook Pro (M1 Pro, 16 GB RAM) with macOS Sonoma 14.2.

sck-at-ucy commented 4 months ago

I am experiencing something similar. I have implemented a model that tries to backtrack the edge boundary conditions of a 2D heat-conduction problem, given as input the final temperature distribution in the interior of the plate (but not the boundary conditions). The Torch version is indeed about 10 times faster, but I haven't traced the bottleneck. In both cases the optimizer is Adam with the same learning rate.

MLX Version

import mlx.core as mx
import mlx.nn as nn

class BCModel(nn.Module):
    def __init__(self, ny, initial_left_bc, initial_right_bc):
        super().__init__()
        # Scalar boundary conditions stored as 1-element arrays;
        # mx.array members of an nn.Module are treated as trainable parameters
        self.left_bc = mx.array([initial_left_bc])
        self.right_bc = mx.array([initial_right_bc])

    def apply_bc(self, T):
        # A column of T has T.shape[0] entries (nx), so size the fill from T
        n_rows = T.shape[0]
        # Apply scalar Dirichlet boundary conditions uniformly across the boundary
        T[:, 0] = self.left_bc * mx.ones(n_rows)
        T[:, -1] = self.right_bc * mx.ones(n_rows)
        T[0, :] = T[1, :]  # Neumann BC
        T[-1, :] = T[-2, :]  # Neumann BC
        return T

Torch Version

import torch
import torch.nn as nn

class BCModel(nn.Module):
    def __init__(self, ny, initial_left_bc, initial_right_bc):
        super().__init__()
        # Define left_bc and right_bc as nn.Parameters (float scalars, so they can require gradients)
        self.left_bc = nn.Parameter(torch.tensor(float(initial_left_bc)))
        self.right_bc = nn.Parameter(torch.tensor(float(initial_right_bc)))

    def apply_bc(self, T, ny):
        T[:, 0] = self.left_bc.expand(ny)
        T[:, -1] = self.right_bc.expand(ny)
        T[0, :] = T[1, :]  # Neumann BC
        T[-1, :] = T[-2, :]  # Neumann BC
        return T

awni commented 4 months ago

Would be great for both cases if you could share the code with the timing you used to be sure we are all measuring the same things!
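Because MLX evaluates lazily, the time spent in a line like loss_and_grad_fn(...) may only cover building the graph, while the actual work runs at the next mx.eval (or .item(), print, and so on), possibly together with work queued earlier in the loop. PyTorch on the MPS backend also runs asynchronously. A minimal timing sketch that forces completion inside the timed region; benchmark and its parameters are hypothetical names, and the sync callback is an assumption to adapt per framework (for example lambda out: mx.eval(out) for MLX, or lambda out: torch.mps.synchronize() for PyTorch on MPS):

```python
import statistics
import time


def benchmark(step_fn, sync_fn, warmup=3, iters=20):
    """Return the median wall-clock seconds of step_fn, including sync_fn."""
    # Warm-up runs absorb one-time costs such as kernel compilation and allocation
    for _ in range(warmup):
        sync_fn(step_fn())
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        out = step_fn()
        sync_fn(out)  # block until the queued work has actually finished
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```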

sck-at-ucy commented 4 months ago

Happy to share the code, needs a little cleaning first 😅. Coming soon ...

sck-at-ucy commented 4 months ago

BC_Backtrack_MLX.py.zip BC_Bactrack_Torch.py.zip

I have attached the two versions of the code. Some observations follow:

  1. For the same number of epochs (e.g. 500), Torch achieves better loss reduction and a better estimate of the BCs in about 450 s, while MLX (as I have implemented it 😬) takes about 4X as long, and its loss reduction and BC estimate are worse.

  2. In compute_field_loss() I could not use max_steps beyond roughly 5000, even when I inserted some mx.eval calls. With Torch there was no issue, whatever the value of max_steps.

  3. There is a good chance I am making some obvious mistakes in my MLX implementation, since I am still learning it. For example, I am not sure whether I needed to do the equivalent of loss.backward() (computing the gradients) separately. I would be thankful for any comments that help me improve the implementation.

sck-at-ucy commented 4 months ago

BC_Backtrack_MLX_v2.py.zip

OK, this updated version of the MLX implementation solves the problem with max_steps, it can now handle arbitrary iterations. However, I only "kind of" understand why my fix worked.

The slowness problem remains, although I haven't tried larger arrays (nx, ny). In my experience MLX generally starts to have an advantage on larger problems.

awni commented 4 months ago

@sck-at-ucy could you say a bit more about what it was doing when it didn't work (i.e. going beyond 5k)? I notice you added:

    for step in range(max_steps):
        if step % 5000 == 0: mx.eval(T)

Presumably it was crashing? On my machine both of your examples run and seem to produce the same result (I have max_steps at the default of 5000).

sck-at-ucy commented 4 months ago

If max_steps is set to the default of 5000, both versions of the MLX code work; this seems to be an upper limit. If I want more iterations, the first version, where the iterations are done inside the loss function, stops working. Adding a periodic mx.eval(T_model) inside the loop, as shown below, does not help:

import mlx.core as mx

# update_fd_method() is defined elsewhere in the attached script.
def compute_field_loss(model, T_reference, nx, ny, dx, dy, dt_alpha_ov_dxsq, dt_alpha_ov_dysq, max_steps, edge_weight):
    T_model = mx.mean(T_reference) * mx.ones((nx, ny))
    T_model = model.apply_bc(T_model)

    #if epoch % 10 == 0: print(f'epoch: {epoch} and {T_model}')
    for step in range(max_steps):
        # Periodically force evaluation of the lazily built graph
        if step % 1000 == 0:
            mx.eval(T_model)
        T_model = update_fd_method(T_model, dt_alpha_ov_dxsq, dt_alpha_ov_dysq)
        T_model = model.apply_bc(T_model)

    diff = T_model - T_reference
    temp_range = mx.max(T_reference) - mx.min(T_reference)
    temp_range = mx.maximum(temp_range, 1e-6)  # Avoid division by zero
    normalized_diff = diff / temp_range

    # Apply different weights to edge-adjacent and central nodes
    edge_mask = mx.ones_like(T_model)
    edge_mask[:, 0:4] = 0  # Exclude leftmost Dirichlet boundary
    edge_mask[:, -4:] = 0  # Exclude rightmost Dirichlet boundary
    # Note: columns 1 and -2 were already zeroed above, so these two
    # weights have no effect as written.
    edge_mask[:, 1] *= edge_weight  # Left edge-adjacent nodes
    edge_mask[:, -2] *= edge_weight  # Right edge-adjacent nodes

    weighted_diff = edge_mask * normalized_diff

    loss = mx.sum((weighted_diff[:, 1:-1]) ** 2)
    return loss

The code does not crash and does not produce any error messages, but execution goes to zombie land: it never completes. If I look at the GPU history, after an initial spike the utilization drops and seems to wander around randomly; presumably execution has stalled?

[Screenshot: GPU history, MLX version that does not work]

Now, in the refactored version, the finite difference iterations have been moved from within the loss function into a separate method, and there adding if step % 5000 == 0: mx.eval(T) works, so I can use any value of max_steps. The difference in the GPU history is clear: utilization is now good.

[Screenshot: GPU history, MLX version that works]

I have encountered this before: the code simply zombies out rather than crashing hard. It would be nice to understand this behavior better so that, even if it is not fixable, I can at least write code in a way that does not trigger it. I had a hunch that moving the iterations out of the loss function would help, but I am not sure why it worked, so next time I might trigger it again.

sck-at-ucy commented 4 months ago

I am adding the latest version of the code that includes some corrections in the handling of plots and movie generation at the end.

This version saves frames inside the loss function; that part needs to be commented out for speed comparisons.

My naive conclusion: including the finite difference iterations directly in the loss function makes the graphs more complex to track. But the zombie behavior is concerning.

The slowness for this size of problem relative to torch is an open question. Will try larger problems to see how the comparison goes.

BC_Backtrack_MLX_v2.2.py.zip

sck-at-ucy commented 4 months ago

As expected, increasing the problem size (nx * ny) to 256 x 256 (from the previous 64 x 64) shows that, for the same number of epochs and similar loss performance, MLX takes about 251 s vs. about 377 s for Torch. This matches my experience in other cases: for small problems Torch has an advantage, but that quickly reverses as the problem size increases.

awni commented 4 months ago

We'd like to get better at small sizes too. Thanks for the detailed benchmark! Perf is a high priority right now, and more benchmarks to examine are very helpful! I'll keep you posted on any findings with the one you submitted. The hang is very strange and should not happen.

sck-at-ucy commented 4 months ago

I would be happy to run any additional benchmarks with these codes if that would be useful. I am still exploring; if I notice anything worth sharing, I will post it.

cowolff commented 4 months ago

I have now also published my code on GitHub, with a requirements.txt so you can recreate my results: https://github.com/cowolff/ppo_mlx

As mentioned earlier, line 245 of ppo_mlx.py takes much longer than the equivalent in PyTorch:

loss, grads = loss_and_grad_fn(agent, pg_loss, ent_coef, entropy_loss, v_loss, vf_coef)

Please also note that my code still has a few other issues, so my PPO implementation does not yet actually learn anything. But that doesn't change the fact that calculating the loss and gradients takes very long. Thanks for your help!!