ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[QUESTION] Can we integrate MPSGraph features directly in MLX? #1585

Open thegodone opened 1 week ago

thegodone commented 1 week ago

Describe the bug

This is linked to #1500. Can we reuse or plug in the MPSGraph GRU API directly through MLX? It is declared in MetalPerformanceShadersGraph/MPSGraphRNNOps.h.

To Reproduce

Include code snippet

std::tuple<Tensor, Tensor> gru_cell_mps(
    const Tensor& input,         // x_t
    const Tensor& hidden_state,  // h_{t-1}
    const Tensor& w_ih,          // Input weight (W_z, W_r, W_h concatenated)
    const Tensor& w_hh,          // Hidden weight (U_z, U_r, U_h concatenated)
    const Tensor& b_ih,          // Bias for input weights (b_z, b_r, b_h concatenated)
    const Tensor& b_hh           // Bias for hidden weights (secondary bias if reset_after = YES)
) {
    using namespace mps;

    MPSStream* stream = getCurrentMPSStream();
    @autoreleasepool {
        MPSGraph* graph = [[MPSGraph alloc] init];

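        // Define graph inputs as placeholders. MPSGraph has no method that
        // builds a constant from an at::Tensor; PyTorch's MPS backend uses the
        // mpsGraphRankedPlaceHolder helper and feeds the data at run time.
        MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(graph, input);
        MPSGraphTensor* hiddenTensor = mpsGraphRankedPlaceHolder(graph, hidden_state);
        MPSGraphTensor* wIhTensor = mpsGraphRankedPlaceHolder(graph, w_ih);
        MPSGraphTensor* wHhTensor = mpsGraphRankedPlaceHolder(graph, w_hh);
        MPSGraphTensor* bIhTensor = mpsGraphRankedPlaceHolder(graph, b_ih);
        MPSGraphTensor* bHhTensor = mpsGraphRankedPlaceHolder(graph, b_hh);

        // Create GRU descriptor
        MPSGraphGRUDescriptor* descriptor = [MPSGraphGRUDescriptor descriptor];
        descriptor.reverse = NO;
        descriptor.bidirectional = NO;
        descriptor.training = NO;  // No training state needed
        descriptor.resetGateFirst = YES;  // Assuming gate order is r, z, h
        descriptor.resetAfter = YES;     // Use the "reset-after" formulation

        // Apply GRU operation for a single step (T=1).
        // The source tensor is expected in [T, N, I] layout, so this assumes
        // `input` already carries a leading time dimension of size 1.
        NSArray<MPSGraphTensor*>* gruOutput = [graph GRUWithSourceTensor:inputTensor
                                                        recurrentWeight:wHhTensor
                                                            inputWeight:wIhTensor
                                                                   bias:bIhTensor
                                                              initState:hiddenTensor
                                                          secondaryBias:bHhTensor
                                                             descriptor:descriptor
                                                                   name:@"gru_cell"];

        // Extract output tensors
        MPSGraphTensor* nextHiddenState = gruOutput[0];  // h_t, laid out as [T, N, H]

        // Allocate output tensor (same element count as h_t when T=1)
        Tensor outputHidden = at::empty_like(hidden_state);

        NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
        NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [NSMutableDictionary dictionary];

        // Bind every graph placeholder to its tensor data; the original left
        // `feeds` empty, so the graph would have had no inputs.
        Placeholder inputPlaceholders[] = {
            Placeholder(inputTensor, input),   Placeholder(hiddenTensor, hidden_state),
            Placeholder(wIhTensor, w_ih),      Placeholder(wHhTensor, w_hh),
            Placeholder(bIhTensor, b_ih),      Placeholder(bHhTensor, b_hh)};
        for (auto& placeholder : inputPlaceholders) {
            [feeds setObject:placeholder.getMPSGraphTensorData() forKey:placeholder.getMPSGraphTensor()];
        }

        Placeholder hiddenPlaceholder(nextHiddenState, outputHidden);
        [results setObject:hiddenPlaceholder.getMPSGraphTensorData() forKey:hiddenPlaceholder.getMPSGraphTensor()];

        // Run the graph
        runMPSGraph(stream, graph, feeds, results);

        // For a single GRU step the output and the new hidden state coincide,
        // so the same tensor is returned for both.
        return std::make_tuple(outputHidden, outputHidden);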
    }
}
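
For checking whatever the MPSGraph call returns, a plain CPU reference of one GRU step may be useful. The sketch below assumes the PyTorch-style convention that the parameter names in the snippet above suggest: gates stacked in the order r, z, n, the reset gate applied after the recurrent matrix product ("reset-after"), and h_t = (1 - z) * n + z * h_{t-1}. Whether MPSGraph's GRU with resetGateFirst = YES and resetAfter = YES matches this term for term, in particular the update-gate convention and how the secondary bias is applied, is an assumption to verify against MPSGraphRNNOps.h. The names gru_cell_reference, affine, Vec and Mat are hypothetical helpers, not part of MLX, PyTorch or MPSGraph.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major: m[row][col]

// Computes y = m * x + b for a dense matrix m.
static Vec affine(const Mat& m, const Vec& x, const Vec& b) {
    assert(m.size() == b.size());
    Vec y(b);
    for (std::size_t i = 0; i < m.size(); ++i) {
        assert(m[i].size() == x.size());
        for (std::size_t j = 0; j < x.size(); ++j) {
            y[i] += m[i][j] * x[j];
        }
    }
    return y;
}

static double sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

// One GRU step in the reset-after formulation (assumed convention).
// Shapes: x is [I], h is [H], w_ih is [3H][I], w_hh is [3H][H],
// b_ih and b_hh are [3H]. Gate blocks are stacked as r, z, n.
Vec gru_cell_reference(const Vec& x, const Vec& h,
                       const Mat& w_ih, const Mat& w_hh,
                       const Vec& b_ih, const Vec& b_hh) {
    const std::size_t H = h.size();
    const Vec gi = affine(w_ih, x, b_ih);  // input contribution, length 3H
    const Vec gh = affine(w_hh, h, b_hh);  // recurrent contribution, length 3H
    assert(gi.size() == 3 * H && gh.size() == 3 * H);

    Vec h_next(H);
    for (std::size_t k = 0; k < H; ++k) {
        const double r = sigmoid(gi[k] + gh[k]);          // reset gate
        const double z = sigmoid(gi[H + k] + gh[H + k]);  // update gate
        // Reset-after: r scales the recurrent term after the matrix product.
        const double n = std::tanh(gi[2 * H + k] + r * gh[2 * H + k]);
        h_next[k] = (1.0 - z) * n + z * h[k];
    }
    return h_next;
}
```

With all weights and biases set to zero, both gates evaluate to 0.5 and the candidate state to 0, so the function returns half of the previous hidden state. That makes a quick sanity check before comparing against a GPU result.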

Looking at the PyTorch MPS backend, it wires up LSTM but not GRU: https://github.com/pytorch/pytorch/blob/1886e33f6096175e6f0f77f4b44a39110d2656d6/aten/src/ATen/native/mps/operations/RnnOps.mm

Desktop (please complete the following information):