pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Use case suggestion: Parallel execution of a batch of linear regressions on GPU #107

Open catskillsresearch opened 3 years ago

catskillsresearch commented 3 years ago

Here is a good use case: vectorizing linear regression over a sequence of X-Y datasets of different lengths. The goal is to ship every element of the batch to the GPU and run them all fully in parallel: https://github.com/tensorflow/tensorflow/issues/51771
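For contrast, the equal-length version of this problem already batches cleanly. A minimal sketch, assuming PyTorch 2.x (where functorch's `vmap` is available as `torch.func.vmap`) and a single-feature model y ≈ a + b*x; `fit_line` is a hypothetical helper, not part of any library. Every dataset is stacked into one (B, N) tensor and `vmap` runs the per-dataset fit across the batch dimension:

```python
import torch
from torch.func import vmap  # functorch's vmap, exposed under torch.func in PyTorch 2.x

def fit_line(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Ordinary least squares for y ~ a + b*x on ONE dataset; returns [a, b]."""
    x_mean = x.mean()
    y_mean = y.mean()
    # slope = covariance(x, y) / variance(x)
    b = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
    a = y_mean - b * x_mean
    return torch.stack([a, b])

device = "cuda" if torch.cuda.is_available() else "cpu"

# B datasets that all have exactly N points (the easy, non-ragged case)
B, N = 4, 100
x = torch.linspace(0, 1, N, device=device).expand(B, N)
true_a = torch.arange(B, dtype=torch.float32, device=device)
true_b = 2.0 * torch.arange(1, B + 1, dtype=torch.float32, device=device)
y = true_a[:, None] + true_b[:, None] * x

# One call fits all B lines as a single batched computation; result has shape (B, 2)
coeffs = vmap(fit_line)(x, y)
print(coeffs)
```

The difficulty raised in this issue starts when the datasets have different N, because a plain (B, N) tensor can no longer hold them.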

Chillee commented 3 years ago

The main bottleneck here is the lack of a mechanism for batched computation over ragged tensors, which we are not currently working on (although we have thought about it!).

It's definitely a hard problem, unfortunately.

catskillsresearch commented 3 years ago

@Chillee, by padding with 0s out to the maximum tensor length and studying the quadratic regression formula closely, it is possible to hand-vectorize regression for unequal problem sizes. In the end it's not rocket science, but this sort of thing could definitely be made easier in platforms like Torch and TensorFlow. For more algorithmic computations, the "pure" tensor and flow-graph layer of both Torch and TensorFlow is not broken out in a way that clearly separates it from all the machinery layered on top for machine learning. This is unfortunate, because it pushes algorithmic developers away from these frameworks and toward lower-level CUDA approaches.
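A minimal sketch of the zero-padding idea, assuming the simple single-feature model y ≈ a + b*x (the comment above does not spell out the exact formula used). The closed-form least-squares solution depends only on n, Σx, Σy, Σxy and Σx². Padded zeros contribute nothing to any of those sums, so the padded batch gives the right answer as long as n is each dataset's true length rather than the padded length. `batched_line_fit` is a hypothetical helper; `torch.nn.utils.rnn.pad_sequence` is the standard PyTorch padding utility.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def batched_line_fit(xs, ys):
    """Fit y ~ a + b*x independently for each (x, y) pair of 1-D tensors of varying length."""
    device = xs[0].device
    n = torch.tensor([len(x) for x in xs], dtype=torch.float32, device=device)
    # Zero-pad every dataset out to the longest one: shape (B, max_len)
    X = pad_sequence(xs, batch_first=True, padding_value=0.0)
    Y = pad_sequence(ys, batch_first=True, padding_value=0.0)
    # Padded zeros add nothing to these sums, so they equal the sums over the real points
    Sx, Sy = X.sum(dim=1), Y.sum(dim=1)
    Sxy, Sxx = (X * Y).sum(dim=1), (X * X).sum(dim=1)
    # Closed-form least squares; the TRUE length n (not max_len) must appear here
    b = (n * Sxy - Sx * Sy) / (n * Sxx - Sx ** 2)
    a = (Sy - b * Sx) / n
    return a, b

# Usage: three datasets of different lengths, fitted in one batched pass
device = "cuda" if torch.cuda.is_available() else "cpu"
lengths = [5, 12, 30]
xs = [torch.linspace(1, 2, k, device=device) for k in lengths]
ys = [3.0 + 0.5 * i + (1.0 + i) * x for i, x in enumerate(xs)]
a, b = batched_line_fit(xs, ys)
print(a, b)
```

The raw-sums form is the easiest to vectorize, but it can lose float32 precision when x is large or far from zero; centering each dataset first (using the true n) is a more robust variant of the same trick.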