ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Compile behavior #712

Closed sck-at-ucy closed 8 months ago

sck-at-ucy commented 8 months ago

I have two slightly different implementations of model training with compile in my code; one works, the other fails, and I do not understand the cause. data_loader_2D is a function that selects datasets for training and validation from precomputed solutions. It includes an option for random shuffling, but turning it off made no difference.
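For reference, data_loader_2D roughly follows this interface (a simplified sketch, not the actual implementation): it takes a (src, target) dataset and a batch size, optionally shuffles, and yields batches of mx.array pairs.

import mlx.core as mx
import numpy as np

def data_loader_2D(data, batch_size, shuffle=True):
    # data is assumed to be a (src, target) pair of mx.array with a shared leading sample dimension
    src, target = data
    indices = np.arange(src.shape[0])
    if shuffle:
        np.random.shuffle(indices)
    for start in range(0, src.shape[0], batch_size):
        batch = mx.array(indices[start:start + batch_size])
        yield src[batch], target[batch]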

1. Successful Implementation

def loss_fn_2D(model, src, target, alpha, dx, dy, dt, ny, nx):
    # Weights for physics-informed and BC-departure losses
    boundary_loss_weight = 0.005
    physics_loss_weight = 0.001

    # Direct model output and target
    model_output_flat = model(src)  # Assuming model(src) returns flattened 2D grid predictions
    target_flat = target  # Assuming target is already in the correct flattened format

    # MSE Loss
    mse_loss = nn.losses.mse_loss(model_output_flat, target_flat, reduction='mean')

    # Physics-informed Loss (adapted for 2D)
    pi_loss = physics_informed_loss_2D(model_output_flat, alpha, dx, dy, dt, ny, nx)

    # Boundary Loss (adapted for 2D)
    boundary_loss = compute_boundary_loss_2D(model_output_flat, target_flat, dx, dy, dt, ny, nx)

    # Combine MSE loss, Physics-informed loss, and Boundary loss
    total_loss = mse_loss + boundary_loss_weight * boundary_loss + physics_loss_weight * pi_loss

    return total_loss

# The training loop and supporting functions to perform gradient descent using the Adam optimizer.
# The state that will be captured as input and output
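# (state is defined earlier in the script; typically something like state = [model.state, optimizer.state])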

@partial(mx.compile, inputs=state, outputs=state)
def evaluate_step(src, target, alpha, dx, dy, dt, ny, nx):
    # Direct loss calculation without reshaping
    loss = loss_fn_2D(model, src, target, alpha, dx, dy, dt, ny, nx)
    return loss

@partial(mx.compile)
def train_and_validate(train_data, validation_data, batch_size, epochs, alpha, dx, dy, dt, ny, nx):
    tic = time.perf_counter()
    for epoch in range(epochs):
        mx.eval(state)
        total_train_loss = 0
        num_train_batches = 0

        for src, target in data_loader_2D(train_data, batch_size):
            loss_and_grad_fn = nn.value_and_grad(model, loss_fn_2D)
            loss, grads = loss_and_grad_fn(model, src, target, alpha, dx, dy, dt, ny, nx)

            # Update model parameters
            optimizer.update(model, grads)
            mx.eval(state)
            total_train_loss += loss.item()
            num_train_batches += 1

        total_val_loss = 0
        num_val_batches = 0

        #  loop for validation phase
        for src, target in data_loader_2D(validation_data, batch_size, shuffle=False):
            val_loss = evaluate_step(src, target, alpha, dx, dy, dt, ny, nx)
            total_val_loss += val_loss.item()
            num_val_batches += 1

        # Print epoch summary
        print(
            f'Epoch {epoch + 1}, Training Loss: {total_train_loss / num_train_batches}, Validation Loss: {total_val_loss / num_val_batches}')
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 5
    print(f"Time per iteration {tpi:.3f} (ms)")

2. Failing Implementation

In the second implementation, which fails, the value_and_grad call is moved into a train_step() function, but the code is otherwise the same.

@partial(mx.compile, inputs=state, outputs=state)
def train_step(src, target, alpha, dx, dy, dt, ny, nx):
    # Calculate loss and gradients
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn_2D)
    loss, grads = loss_and_grad_fn(model, src, target, alpha, dx, dy, dt, ny, nx)

    # Update model parameters
    optimizer.update(model, grads)
    return loss

@partial(mx.compile, inputs=state, outputs=state)
def evaluate_step(src, target, alpha, dx, dy, dt, ny, nx):
    # Direct loss calculation without reshaping
    loss = loss_fn_2D(model, src, target, alpha, dx, dy, dt, ny, nx)
    return loss

@partial(mx.compile)
def train_and_validate(train_data, validation_data, batch_size, epochs, alpha, dx, dy, dt, ny, nx):
    tic = time.perf_counter()
    for epoch in range(epochs):
        mx.eval(state)
        total_train_loss = 0
        num_train_batches = 0

        for src, target in data_loader_2D(train_data, batch_size):
            loss  = train_step(src, target, alpha, dx, dy, dt, ny, nx)
            mx.eval(state)
            total_train_loss += loss.item()
            num_train_batches += 1

        total_val_loss = 0
        num_val_batches = 0

        #  loop for validation phase
        for src, target in data_loader_2D(validation_data, batch_size, shuffle=False):
            val_loss = evaluate_step(src, target, alpha, dx, dy, dt, ny, nx)
            total_val_loss += val_loss.item()
            num_val_batches += 1

        # Print epoch summary
        print(
            f'Epoch {epoch + 1}, Training Loss: {total_train_loss / num_train_batches}, Validation Loss: {total_val_loss / num_val_batches}')
    toc = time.perf_counter()
    tpi = 1e3 * (toc - tic) / 5
    print(f"Time per iteration {tpi:.3f} (ms)")

This version with the separate train_step fails with `IndexError: unordered_map::at: key not found`:

Traceback (most recent call last):
  File "/Users/m2/PycharmProjects/pythonProject_StreamLit/Transformer_2D_HeatEqn_NewStart_v2_compilefailed.py", line 684, in <module>
    train_and_validate(training_data, validation_data, batch_size, epochs, alpha, dx, dy, dt, ny, nx)
  File "/Users/m2/PycharmProjects/pythonProject_StreamLit/Transformer_2D_HeatEqn_NewStart_v2_compilefailed.py", line 660, in train_and_validate
    loss  = train_step(src, target, alpha, dx, dy, dt, ny, nx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: unordered_map::at: key not found
awni commented 8 months ago

In both cases you should not compile the outer training loop, which itself contains `mx.eval`. E.g. this:

@partial(mx.compile)
def train_and_validate(train_data, validation_data, batch_size, epochs, alpha, dx, dy, dt, ny, nx):

should not be compiled. You cannot evaluate the graph inside a compiled function, so that is almost always going to crash.

If you remove that, the compiled train_step and evaluate_step should work as you have them, assuming you do not do any evals inside those functions (e.g. by casting to NumPy or calling `mx.eval`).
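In other words, keep the per-step functions compiled as you have them and leave the outer loop uncompiled, roughly like this (sketch; state is assumed to capture the model and optimizer state as in your snippets):

@partial(mx.compile, inputs=state, outputs=state)
def train_step(src, target, alpha, dx, dy, dt, ny, nx):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn_2D)
    loss, grads = loss_and_grad_fn(model, src, target, alpha, dx, dy, dt, ny, nx)
    optimizer.update(model, grads)
    return loss

# The outer loop is a plain (uncompiled) Python function, so mx.eval is fine here.
def train_and_validate(train_data, validation_data, batch_size, epochs, alpha, dx, dy, dt, ny, nx):
    for epoch in range(epochs):
        for src, target in data_loader_2D(train_data, batch_size):
            loss = train_step(src, target, alpha, dx, dy, dt, ny, nx)
            mx.eval(state)  # evaluate the captured state outside the compiled step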

sck-at-ucy commented 8 months ago

Indeed, this was the problem, thank you!