llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Providing complete support for Facebook's DLRM model. #1179

Open vidsinghal opened 2 years ago

vidsinghal commented 2 years ago

Hello, I've been working on providing support for Facebook's DLRM model (https://github.com/facebookresearch/dlrm) via Torch-MLIR. So far, I've been able to trace a simple path through the DLRM code, similar to the sample run of the model described here: https://github.com/facebookresearch/dlrm/blob/main/README.md#how-to-run-dlrm-code

To make it simpler to torch.jit.script or torch.jit.trace their main neural net (https://github.com/facebookresearch/dlrm/blob/97272b580a7a0d4ef98de6fdb654a442876921be/dlrm_s_pytorch.py#L195), I have checked in this script: https://github.com/nod-ai/SHARK/pull/185

The script calls the main neural net code on a mini-batch to obtain the corresponding TorchScript, which can then be lowered to Linalg via torch_mlir.compile.

Note that I had to modify the original Python implementation of the DLRM_Net class in several places in order to get TorchScript with torch.jit.trace. These changes to the original implementation would need to be incorporated in order to lower the model successfully via Torch-MLIR. In addition, this only captures the sequential_forward code, which is the path we are looking to support through the SHARK runtime.

The main challenge I found while working with this model is that it contains a lot of control flow that depends on the model inputs. This can create many different paths through the program, and torch.jit.trace does not capture all of them. To fully support this model with Torch-MLIR, it would be ideal to script it with torch.jit.script and then lower it to Linalg via torch_mlir.compile, but this does not currently work due to issues with shape refinement (more on this later).
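To make the tracing limitation concrete, here is a toy sketch in plain Python. It does not use torch.jit.trace, and the flag and op names are hypothetical. A recorder logs each operation as it runs, so an ordinary Python `if` is resolved at capture time and only the branch taken for that particular input ends up in the recorded program; a different input produces a different straight-line trace.

```python
from typing import Callable, List


class Recorder:
    """Collects the names of the operations executed during one call."""

    def __init__(self) -> None:
        self.ops: List[str] = []

    def op(self, name: str, value: float) -> float:
        # Record the op name and pass the value through unchanged.
        self.ops.append(name)
        return value


def toy_forward(rec: Recorder, x: float, interaction: str) -> float:
    # `interaction` is a hypothetical flag; the `if` is ordinary Python,
    # so only the branch taken for this call is ever recorded.
    y = rec.op("bottom_mlp", 2.0 * x)
    if interaction == "dot":
        y = rec.op("dot", y * y)
    else:
        y = rec.op("cat", y + 1.0)
    return rec.op("top_mlp", y)


def toy_trace(fn: Callable[..., float], *args) -> List[str]:
    """Run `fn` once and return the straight-line list of recorded ops."""
    rec = Recorder()
    fn(rec, *args)
    return rec.ops


print(toy_trace(toy_forward, 1.0, "dot"))  # ['bottom_mlp', 'dot', 'top_mlp']
print(toy_trace(toy_forward, 1.0, "cat"))  # ['bottom_mlp', 'cat', 'top_mlp']
```

This is the same reason a single torch.jit.trace run covers only one path: untaken branches never appear in the captured graph, whereas torch.jit.script compiles the `if` itself.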

With torch.jit.trace, I've been able to trace the simple run with the custom script mentioned earlier. However, to get to that point I had to implement the embeddingBag op, which this model relies on heavily. Support for the sum mode of this op, which is the mode this model uses, was merged in this PR:

1.) https://github.com/llvm/torch-mlir/pull/1066

The other PRs and issues that were incorporated are:

2.) https://github.com/llvm/torch-mlir/pull/1119
3.) https://github.com/llvm/torch-mlir/pull/1097/files
4.) https://github.com/llvm/torch-mlir/pull/862/files
5.) https://github.com/llvm/torch-mlir/issues/1075
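For reference, the semantics of the sum mode can be sketched in plain Python. This is our reading of PyTorch's behavior for 1-D `indices` with `offsets` and `include_last_offset=False` (no per-sample weights, no padding index). It is not the torch-mlir lowering from the PR above, and `embedding_bag_sum` is a hypothetical helper name. Bag `i` sums the rows of `weight` selected by `indices[offsets[i]:offsets[i + 1]]`, with the last bag running to the end of `indices`.

```python
from typing import List, Sequence


def embedding_bag_sum(
    weight: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
) -> List[List[float]]:
    """Sum-mode embedding bag over 1-D indices (include_last_offset=False).

    Each output row is the sum of the selected rows of `weight`;
    an empty bag yields a row of zeros.
    """
    dim = len(weight[0])
    # Append len(indices) so the last bag runs to the end of `indices`.
    bounds = list(offsets) + [len(indices)]
    out: List[List[float]] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        row = [0.0] * dim
        for idx in indices[start:end]:
            for j in range(dim):
                row[j] += weight[idx][j]
        out.append(row)
    return out


weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
# Two bags: indices[0:2] = [0, 2] and indices[2:3] = [1].
print(embedding_bag_sum(weight, indices=[0, 2, 1], offsets=[0, 2]))
# [[3.0, 2.0], [0.0, 1.0]]
```

A small reference like this can be handy for checking a lowered embedding bag against hand-built inputs.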

The above trace runs through the SHARK runtime on the CPU and GPU.

After incorporating the PRs above, the sample run of the code can be lowered to Linalg. However, this covers nowhere near all of the paths and ops required to fully support this model.

(Back to torch.jit.script.) Currently, the scripted model cannot be lowered by Torch-MLIR because the shapes propagated through different branches do not match. This causes an error similar to the following:

```
raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: 'torch.prim.If' op  along control flow edge from Region #0 to parent results: source type #0 '!torch.tensor<[4,2],f32>' should match input type #0 '!torch.tensor'
note: see current operation: %76:2 = "torch.prim.If"(%75) ({
"torch.operator"(%9, %15) {name = "aten.warn"} : (!torch.str, !torch.int) -> ()
"torch.prim.If.yield"(%71, %44) : (!torch.tensor<[4,2],f32>, !torch.tensor) -> ()
}, {
"torch.prim.If.yield"(%44, %71) : (!torch.tensor, !torch.tensor<[4,2],f32>) -> ()
}) : (!torch.bool) -> (!torch.tensor, !torch.tensor)

Error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='torchscript-module-to-torch-backend-pipeline' /tmp/DLRM_Net.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
```

DLRM_Net.mlir: https://gist.github.com/vid-999/835be80b2c543b580c4a7642bac6abcd
Sample code to reproduce: https://gist.github.com/vid-999/b32e648eef0565ea5928efcf27018678
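One way to read the diagnostic: every `torch.prim.If.yield` operand must have exactly the type of the corresponding `torch.prim.If` result. Shape refinement has refined `%71` to `!torch.tensor<[4,2],f32>`, but the two branches yield `%71` and `%44` in swapped slots, so neither result can be refined and both stay `!torch.tensor`. The sketch below models this with a hypothetical `TensorType` class (not a torch-mlir API). It joins the yielded types slot by slot and reports which yields are more specific than their result and would therefore need a cast that drops the refined information. It illustrates the constraint only; it is not the actual fix in torch-mlir.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a Torch dialect tensor type.

    None means "unknown", so TensorType() models the unrefined !torch.tensor.
    """
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None


def join(a: TensorType, b: TensorType) -> TensorType:
    # Keep only the facts both branches agree on.
    shape = a.shape if a.shape == b.shape else None
    dtype = a.dtype if a.dtype == b.dtype else None
    return TensorType(shape, dtype)


def if_result_types(branches: List[List[TensorType]]) -> List[TensorType]:
    # Join the yielded types slot by slot across all branches.
    results = list(branches[0])
    for yields in branches[1:]:
        results = [join(r, y) for r, y in zip(results, yields)]
    return results


def yields_needing_cast(
    branches: List[List[TensorType]], results: List[TensorType]
) -> List[Tuple[int, int]]:
    # (branch, slot) pairs whose yielded type is not identical to the result type.
    return [
        (b, i)
        for b, yields in enumerate(branches)
        for i, (y, r) in enumerate(zip(yields, results))
        if y != r
    ]


refined = TensorType((4, 2), "f32")  # models !torch.tensor<[4,2],f32>
unrefined = TensorType()             # models !torch.tensor
# Slot layout of the two torch.prim.If.yield ops in the log above.
branches = [[refined, unrefined], [unrefined, refined]]
results = if_result_types(branches)
print(results)  # [TensorType(shape=None, dtype=None), TensorType(shape=None, dtype=None)]
print(yields_needing_cast(branches, results))  # [(0, 0), (1, 1)]
```

The two flagged pairs correspond to the two refined yields in the log, which is where the verifier complains.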

This would need to be fixed in order to get the complete list of ops needed for this model. A partial list that I was able to collect before this error occurs is:

1.) aten.element_size
2.) quantized.embedding_bag_4bit_rowwise_offsets
3.) quantized.embedding_bag_byte_rowwise_offsets
4.) prim.PythonOp
5.) aten.mse_loss
6.) aten.broadcast_tensors
7.) aten.ne.str
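As the log shows, an op without a registered Torch dialect definition appears as an opaque `torch.operator` carrying the op name (there, `aten.warn`). A small helper can collect such names from an IR dump, for example one produced with `-mlir-print-ir-after-all`. This is a sketch: `unregistered_ops` is a hypothetical name, the generic form is taken from the log, and the custom form `torch.operator "name"(...)` is our assumption about the pretty printer, so verify both against your torch-mlir version.

```python
import re
from collections import Counter
from typing import Dict

# Generic form, as printed in the diagnostic:
#   "torch.operator"(%9, %15) {name = "aten.warn"} : (...) -> ()
_GENERIC = re.compile(r'"torch\.operator"\([^)]*\)\s*\{[^}]*?name\s*=\s*"([^"]+)"')
# Assumed custom (pretty) form: torch.operator "aten.warn"(%9, %15) : ...
_PRETTY = re.compile(r'torch\.operator\s+"([^"]+)"')


def unregistered_ops(mlir_text: str) -> Dict[str, int]:
    """Count op names that appear only as opaque torch.operator instances."""
    counts = Counter(_GENERIC.findall(mlir_text))
    counts.update(_PRETTY.findall(mlir_text))
    return dict(counts)


sample = """
"torch.operator"(%9, %15) {name = "aten.warn"} : (!torch.str, !torch.int) -> ()
%3 = torch.operator "aten.element_size"(%2) : (!torch.tensor) -> !torch.int
%4 = torch.operator "aten.element_size"(%1) : (!torch.tensor) -> !torch.int
"""
print(unregistered_ops(sample))  # {'aten.warn': 1, 'aten.element_size': 2}
```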

Hopefully this gives a good idea of where we are in fully supporting this model and what further actions are needed to support it via Torch-MLIR.

silvasean commented 2 years ago

My intuition is that this model should be traceable. I looked through the code and I didn't find any examples of a non-traceable construct. Can you please link to any data-dependent or shape-dependent control flow in the model?

From glancing at the code, I think that sequential_forward is all we need to support. The parallel/distributed paths are not in scope, since they rely on multi-device/multi-node constructs that are above Torch-MLIR's layer of abstraction.

vidsinghal commented 2 years ago

Hello, by input-dependent control flow I meant the inputs that are given to the model (https://github.com/facebookresearch/dlrm/blob/97272b580a7a0d4ef98de6fdb654a442876921be/dlrm_s_pytorch.py#L893). These could all be traced, but that would require a separate trace for each set of inputs. Scripting may be easier than tracing all the possible inputs to this model.

As a small example, there are three different branches in the create_emb function (https://github.com/facebookresearch/dlrm/blob/97272b580a7a0d4ef98de6fdb654a442876921be/dlrm_s_pytorch.py#L236), which are triggered by different inputs to the model.

silvasean commented 2 years ago

Those are just hyperparameters. Hyperparameters do not change during program execution and are expected to require different traced/scripted program captures per set of hyperparameters.

vidsinghal commented 2 years ago

Yes, I agree. But I thought you could capture them all at once by scripting the model in one go. If scripting is not what you recommend, then we would need to exhaustively try different inputs (hyperparameters) to get traces of all the execution paths, which is what we should work on next.

silvasean commented 2 years ago

We will need to find which hyperparameters are relevant to users and compile in those configurations. It is like testing ResNet18 and ResNet50: they are just hyperparameter variations.
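Following this suggestion, the set of captures could be driven by an explicit list of configurations, with one trace and compile per configuration, just as ResNet18 and ResNet50 are compiled separately. A minimal plain-Python sketch: the grid keys, values, and naming scheme are hypothetical placeholders (not DLRM's actual command-line flags), and the trace/compile step itself is left out.

```python
import itertools
from typing import Dict, Iterator, List


def configurations(grid: Dict[str, List[object]]) -> Iterator[Dict[str, object]]:
    """Yield every combination of the hyperparameter values in `grid`."""
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def capture_name(config: Dict[str, object]) -> str:
    """Build a stable, hypothetical artifact name for one captured program."""
    return "dlrm_" + "_".join(f"{k}-{config[k]}" for k in sorted(config))


# Hypothetical grid; the real list should come from configurations users run.
grid = {"interaction": ["dot", "cat"], "emb_dim": [2, 16]}
names = [capture_name(c) for c in configurations(grid)]
for name in names:
    # Each name stands for one separate trace + compile of the model.
    print(name)
```

Keeping the grid explicit makes "done" measurable: every listed configuration either lowers or has a known blocker.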

vidsinghal commented 2 years ago

Makes sense. In that regard, most ops are already supported, since I've checked in the traced path for the simple run, which I assume is what most users will end up using.

silvasean commented 2 years ago

@powderluv is there any further action item here or set of hyperparameters we need to cover before calling DLRM "done"?

silvasean commented 1 year ago

@powderluv anything else to do here?