orlev2 opened this issue 4 months ago (status: open).
### Bug description

Running inference of a Transformers4Rec model on Triton fails with `RuntimeError: PyTorch execute failure: Expected Tensor but got GenericList`.

Similar issues in the past were solved by updating the model's `forward` function (see, for example, https://github.com/triton-inference-server/server/issues/3348). However, this solution is not straightforward for Transformers4Rec.
### Steps/Code to reproduce bug

1. The ensemble is built and exported with `merlin.systems.dag.Ensemble`:

```python
import torch
from merlin.systems.dag import Ensemble

# strict=False lets the tracer record list/dict (mutable container) outputs
traced_model = torch.jit.trace(model, model_input_dict, strict=False)
transform_workflow = TransformWorkflow.TransformWorkflow(workflow, max_batch_size=0)
pt_predict_workflow = PredictPyTorch.PredictPyTorch(traced_model, workflow.output_schema, model.output_schema)
pipeline = workflow.input_schema.column_names >> transform_workflow >> pt_predict_workflow

ensemble = Ensemble(pipeline, workflow.input_schema)
ensemble.export(f"{ensemble_artifacts_repository_path}", name=model_name)
```
The exported ensemble contains an NVTabular transformation workflow and a Transformers4Rec PyTorch prediction model, with the following structure (the actual list of categorical and continuous features is substantially larger than shown here):
```
ensemble_artifacts_repository
├── model_name
│   ├── config.pbtxt
│   └── 1
│       ├── model.py
│       └── ensemble
│           ├── metadata.json
│           └── ensemble.pkl
├── 1_predictpytorchtriton
│   ├── config.pbtxt
│   └── 1
│       └── model.pt
└── 0_transformworkflowtriton
    ├── config.pbtxt
    └── 1
        ├── model.py
        └── workflow
            ├── metadata.json
            ├── workflow.pkl
            └── categories
                ├── unique.feat1.parquet
                └── unique.feat2.parquet
```
2. Triton is called with:

```shell
curl -X POST -H "Content-Type: application/json" -d @instances.json localhost:8000/v2/models/model_name/infer | jq -c '.[]?'
```
Example `instances.json` file:

```json
{
  "id": "42",
  "inputs": [
    {
      "name": "id",
      "shape": [5],
      "datatype": "BYTES",
      "data": ["ABCDE12345", "ABCDE12345", "ABCDE12345", "ABCDE12345", "ABCDE12345"]
    },
    {
      "name": "event_date",
      "shape": [5],
      "datatype": "BYTES",
      "data": ["2024-02-19", "2024-02-18", "2024-02-18", "2024-03-01", "2024-03-01"]
    },
    {
      "name": "feat1",
      "shape": [5],
      "datatype": "BYTES",
      "data": ["cat1", "cat2", "cat3", "cat4", "cat5"]
    },
    {
      "name": "feat2",
      "shape": [5],
      "datatype": "BYTES",
      "data": ["cat1", "cat2", "cat3", "cat4", "cat5"]
    }
  ]
}
```
3. Feature transformation completes successfully; the error above is raised by `1_predictpytorchtriton`.

### Expected behavior

Triton would return the prediction output.

### Environment details

- Transformers4Rec version: 23.8.00
- Platform: GCP Vertex AI
- Python version: 3.10.13
- Huggingface Transformers version: 4.27.1
- PyTorch version (GPU?): 1.13.1, CUDA 12.1

### Additional context

@niraj06 @evagian
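Hand-writing the `instances.json` request is error-prone: mismatched brackets or a `shape` that does not match the length of `data` will make the request fail before the model is reached. A minimal sketch for generating the request body with the standard library, assuming one 1-D `BYTES` input per column as in the example above (the column names and values are copied from that example; `build_infer_request` is a hypothetical helper, and the real ensemble expects many more feature columns):

```python
import json


def build_infer_request(columns, request_id="42"):
    """Build a Triton (KServe v2) HTTP inference request body.

    Each column becomes one BYTES input whose shape is [number of values],
    mirroring the hand-written instances.json example.
    """
    inputs = []
    for name, values in columns.items():
        values = list(values)
        inputs.append({
            "name": name,
            "shape": [len(values)],  # must match the number of data elements
            "datatype": "BYTES",
            "data": values,
        })
    return {"id": request_id, "inputs": inputs}


columns = {
    "id": ["ABCDE12345"] * 5,
    "event_date": ["2024-02-19", "2024-02-18", "2024-02-18", "2024-03-01", "2024-03-01"],
    "feat1": ["cat1", "cat2", "cat3", "cat4", "cat5"],
    "feat2": ["cat1", "cat2", "cat3", "cat4", "cat5"],
}
payload = build_infer_request(columns)
# Write `text` to instances.json and send it with the curl command from step 2.
text = json.dumps(payload, indent=2)
```

Generating the body from Python keeps `shape` and `data` consistent automatically, which helps rule out request formatting as a cause when debugging the backend error.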