Open thegodone opened 1 week ago
Describe the bug: this is linked to #1500. Can we reuse/plug the MPSGraph GRU API directly through MLX? See MetalPerformanceShadersGraph/MPSGraphRNNOps.h.
To Reproduce
Include code snippet
/**
 * Runs one GRU step (T = 1) on the MPS backend using MPSGraph's built-in GRU op.
 *
 * Gate layout is assumed to follow the PyTorch GRUCell convention (r, z, n),
 * which is why the descriptor sets resetGateFirst = YES and resetAfter = YES.
 * NOTE(review): shapes below follow the parameter docs in
 * MetalPerformanceShadersGraph/MPSGraphRNNOps.h — confirm against the SDK in use.
 *
 * @param input         x_t, expected shape [N, I].
 * @param hidden_state  h_{t-1}, expected shape [N, H].
 * @param w_ih          Input weights (W_r, W_z, W_n concatenated), expected [3H, I].
 * @param w_hh          Hidden weights (U_r, U_z, U_n concatenated), expected [3H, H].
 * @param b_ih          Input bias (b_ir, b_iz, b_in concatenated), expected [3H].
 * @param b_hh          Hidden bias (b_hr, b_hz, b_hn concatenated), expected [3H].
 * @return (h_t, h_t), both of shape [N, H]. For a single step the per-step
 *         output and the new hidden state are the same tensor, so the same
 *         tensor is returned twice (as the original snippet did).
 */
std::tuple<Tensor, Tensor> gru_cell_mps(
    const Tensor& input,        // x_t
    const Tensor& hidden_state, // h_{t-1}
    const Tensor& w_ih,         // Input weight (W_r, W_z, W_n concatenated)
    const Tensor& w_hh,         // Hidden weight (U_r, U_z, U_n concatenated)
    const Tensor& b_ih,         // Bias for input weights
    const Tensor& b_hh          // Bias for hidden weights
) {
  using namespace mps;

  TORCH_CHECK(input.dim() == 2, "gru_cell_mps: expected input of shape [N, I]");
  TORCH_CHECK(hidden_state.dim() == 2, "gru_cell_mps: expected hidden_state of shape [N, H]");

  const int64_t batchSize = hidden_state.size(0);
  const int64_t hiddenSize = hidden_state.size(1);

  TORCH_CHECK(b_ih.dim() == 1 && b_ih.size(0) == 3 * hiddenSize,
              "gru_cell_mps: expected b_ih of shape [3H]");
  TORCH_CHECK(b_hh.dim() == 1 && b_hh.size(0) == 3 * hiddenSize,
              "gru_cell_mps: expected b_hh of shape [3H]");

  // The graph op consumes a sequence [T, N, I]; a cell is the T = 1 case.
  Tensor source = input.unsqueeze(0);

  // With resetAfter = YES the op computes
  //   r, z = f(h R^T + x W^T + b)
  //   n    = f(((h R^T) + b2) * r + x W^T + b)
  // so the r/z parts of b_hh can be folded into the primary bias, and only the
  // candidate-gate part (b_hn) belongs in the secondary bias, which is [H].
  Tensor primaryBias = at::cat(
      {b_ih.narrow(0, 0, 2 * hiddenSize) + b_hh.narrow(0, 0, 2 * hiddenSize),
       b_ih.narrow(0, 2 * hiddenSize, hiddenSize)},
      0);
  Tensor secondaryBias = b_hh.narrow(0, 2 * hiddenSize, hiddenSize).clone();

  // The op produces [T, N, H]; allocate with the leading T = 1 axis and drop it on return.
  Tensor stepOutput = at::empty({1, batchSize, hiddenSize}, hidden_state.options());

  MPSStream* stream = getCurrentMPSStream();

  @autoreleasepool {
    // make_mps_graph() hands back an autoreleased graph; a bare alloc/init would
    // leak if this file is built without ARC (NOTE(review): confirm build flags).
    MPSGraph* graph = make_mps_graph();

    // Inputs are bound as placeholders and fed at run time. MPSGraph has no
    // `constantWithTensor:` selector, so tensors cannot be baked in that way.
    MPSGraphTensor* sourceTensor = mpsGraphRankedPlaceHolder(graph, source);
    MPSGraphTensor* hiddenTensor = mpsGraphRankedPlaceHolder(graph, hidden_state);
    MPSGraphTensor* wIhTensor = mpsGraphRankedPlaceHolder(graph, w_ih);
    MPSGraphTensor* wHhTensor = mpsGraphRankedPlaceHolder(graph, w_hh);
    MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder(graph, primaryBias);
    MPSGraphTensor* secondaryBiasTensor = mpsGraphRankedPlaceHolder(graph, secondaryBias);

    MPSGraphGRUDescriptor* descriptor = [MPSGraphGRUDescriptor descriptor];
    descriptor.reverse = NO;
    descriptor.bidirectional = NO;
    descriptor.training = NO;       // Inference only: no training-state output.
    descriptor.resetGateFirst = YES; // Gate order r, z, n.
    descriptor.resetAfter = YES;     // Reset gate applied after the recurrent matmul.

    // The overload that accepts secondaryBias also takes a mask; none is used here.
    NSArray<MPSGraphTensor*>* gruOutput = [graph GRUWithSourceTensor:sourceTensor
                                                     recurrentWeight:wHhTensor
                                                         inputWeight:wIhTensor
                                                                bias:biasTensor
                                                           initState:hiddenTensor
                                                                mask:nil
                                                       secondaryBias:secondaryBiasTensor
                                                          descriptor:descriptor
                                                                name:@"gru_cell"];

    MPSGraphTensor* nextHiddenState = gruOutput[0]; // h_t with layout [1, N, H]

    Placeholder sourcePlaceholder(sourceTensor, source);
    Placeholder hiddenPlaceholder(hiddenTensor, hidden_state);
    Placeholder wIhPlaceholder(wIhTensor, w_ih);
    Placeholder wHhPlaceholder(wHhTensor, w_hh);
    Placeholder biasPlaceholder(biasTensor, primaryBias);
    Placeholder secondaryBiasPlaceholder(secondaryBiasTensor, secondaryBias);
    Placeholder outputPlaceholder(nextHiddenState, stepOutput);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      sourcePlaceholder.getMPSGraphTensor() : sourcePlaceholder.getMPSGraphTensorData(),
      hiddenPlaceholder.getMPSGraphTensor() : hiddenPlaceholder.getMPSGraphTensorData(),
      wIhPlaceholder.getMPSGraphTensor() : wIhPlaceholder.getMPSGraphTensorData(),
      wHhPlaceholder.getMPSGraphTensor() : wHhPlaceholder.getMPSGraphTensorData(),
      biasPlaceholder.getMPSGraphTensor() : biasPlaceholder.getMPSGraphTensorData(),
      secondaryBiasPlaceholder.getMPSGraphTensor() : secondaryBiasPlaceholder.getMPSGraphTensorData(),
    };
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
    };

    runMPSGraph(stream, graph, feeds, results);
  }

  // Drop the T = 1 axis so callers get [N, H], matching hidden_state.
  Tensor nextHidden = stepOutput.squeeze(0);
  return std::make_tuple(nextHidden, nextHidden);
}
Looking at the PyTorch MPS interface, it wires up LSTM but not GRU: https://github.com/pytorch/pytorch/blob/1886e33f6096175e6f0f77f4b44a39110d2656d6/aten/src/ATen/native/mps/operations/RnnOps.mm
Desktop (please complete the following information):
Describe the bug: this is linked to #1500. Can we reuse/plug the MPSGraph GRU API directly through MLX? See MetalPerformanceShadersGraph/MPSGraphRNNOps.h.
To Reproduce
Include code snippet
Looking at the PyTorch MPS interface, it wires up LSTM but not GRU: https://github.com/pytorch/pytorch/blob/1886e33f6096175e6f0f77f4b44a39110d2656d6/aten/src/ATen/native/mps/operations/RnnOps.mm
Desktop (please complete the following information):