sql-machine-learning / sqlflow

Brings SQL and AI together.
Apache License 2.0

[Investigation] Preprocessing in PyTorch #2276

Open brightcoder01 opened 4 years ago

brightcoder01 commented 4 years ago

In this issue, we investigate data preprocessing solutions for native PyTorch and for fast.ai (a high-level API library built upon PyTorch and pandas).


fast.ai is a library built upon PyTorch that provides high-level APIs to simplify building models. It provides the following transform functions for tabular data.

Transform Functions

  • Categorify: Build the vocabulary and convert each categorical feature into a zero-based id.
  • FillMissing: For each numerical column, fill the N/A items with the column's median, its most common value, or a user-specified value.
  • Normalize: y = (x - mean(x)) / std(x)
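To make the three transforms concrete, here is a minimal stdlib-only sketch of what each one computes. It assumes a column is a plain Python list with `None` marking missing values. The function names and the `strategy` values (`"median"`, `"common"`, `"constant"`) are hypothetical and are not the fast.ai API, which runs the equivalent logic on pandas columns.

```python
import statistics
from typing import Dict, List, Optional


def categorify(values: List[str]) -> Dict[str, int]:
    """Build a vocabulary that maps each distinct category to a zero-based id."""
    # Sorting makes the id assignment deterministic across runs.
    return {v: i for i, v in enumerate(sorted(set(values)))}


def fill_missing(values: List[Optional[float]], strategy: str = "median",
                 fill: Optional[float] = None) -> List[float]:
    """Replace None with the column median, most common value, or a constant."""
    present = [v for v in values if v is not None]
    if strategy == "median":
        fill = statistics.median(present)
    elif strategy == "common":
        fill = statistics.mode(present)
    elif strategy != "constant" or fill is None:
        raise ValueError("use 'median', 'common', or 'constant' with a fill value")
    return [fill if v is None else v for v in values]


def normalize(values: List[float]) -> List[float]:
    """Standardize a column: y = (x - mean(x)) / std(x)."""
    mean = statistics.mean(values)
    # Sample standard deviation, like pandas' default Series.std().
    std = statistics.stdev(values)
    return [(v - mean) / std for v in values]


print(categorify(["b", "a", "b", "c"]))       # {'a': 0, 'b': 1, 'c': 2}
print(fill_missing([1.0, None, 3.0, 10.0]))   # [1.0, 3.0, 3.0, 10.0]
print(normalize([1.0, 2.0, 3.0]))             # [-1.0, 0.0, 1.0]
```

Note that the parentheses in Normalize matter: the mean is subtracted before dividing by the standard deviation.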

All the functions above are built upon pandas.

Analysis and Transform in Training

Each transform function above has an apply_train method. It analyzes the input training data to compute statistics, and then transforms the training data instance by instance using those statistics.

Dataset size: The input training data is a pandas DataFrame, so the whole training set must be loaded into a single process to run the analysis. As a result, the dataset can't be large.

Transform in Prediction/Evaluation

Each transform function also provides an apply_test method. It transforms the input data instance by instance, using the statistics computed by apply_train as the transformation parameters.
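The split between apply_train and apply_test can be sketched as a small stateful class. This is a hypothetical stdlib-only illustration (the class name `NormalizeWithFill` is made up, and it combines FillMissing and Normalize for brevity), not fast.ai's implementation: apply_train computes and stores the statistics, and apply_test only reuses them.

```python
import statistics
from typing import List, Optional


class NormalizeWithFill:
    """Hypothetical transform: fill missing values with the median, then standardize."""

    def __init__(self) -> None:
        # Statistics learned during training; None until apply_train runs.
        self.median: Optional[float] = None
        self.mean: Optional[float] = None
        self.std: Optional[float] = None

    def _transform(self, values: List[Optional[float]]) -> List[float]:
        # Per-instance transform that only uses the stored parameters.
        filled = [self.median if v is None else v for v in values]
        return [(v - self.mean) / self.std for v in filled]

    def apply_train(self, values: List[Optional[float]]) -> List[float]:
        # Analysis step: compute statistics on the training data only.
        self.median = statistics.median(v for v in values if v is not None)
        filled = [self.median if v is None else v for v in values]
        self.mean = statistics.mean(filled)
        self.std = statistics.stdev(filled)
        # Then transform the training data with those statistics.
        return self._transform(values)

    def apply_test(self, values: List[Optional[float]]) -> List[float]:
        if self.mean is None:
            raise RuntimeError("apply_train must be called before apply_test")
        # No analysis here: training statistics are the transform parameters.
        return self._transform(values)


t = NormalizeWithFill()
print(t.apply_train([1.0, None, 3.0]))  # [-1.0, 0.0, 1.0]
print(t.apply_test([5.0, None]))        # [3.0, 0.0]
```

Because apply_test never recomputes statistics, evaluation and prediction data go through exactly the same transformation as the training data. These stored values are also what a serving process would need to reproduce the preprocessing.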

Transform in Serving

For high-performance serving of PyTorch models, we will choose among three options. See #2399 for more details.

Problem: Because the transform functions above are built upon pandas, they can't be serialized into TorchScript or converted into ONNX format, so this preprocessing logic can't be saved together with the model.

PyTorch Native

Other Reference Materials

Deep Learning for Tabular Data using PyTorch: uses the sklearn API to do the data preprocessing. This preprocessing logic can't be converted into TorchScript or ONNX.

brightcoder01 commented 4 years ago

The key challenge is how to save the preprocessing logic into the serialized model (TorchScript or ONNX) for serving.

Proposal Options:

sneaxiy commented 4 years ago

As far as I know, PyTorch does not support string-type tensors. The following code raises an error in PyTorch, so it may be hard to define PyTorch operators that support string columns.

import torch
import numpy

a = numpy.array(['apples', 'foobar', 'cowboy'])
t = torch.Tensor(a)

Error message:

Traceback (most recent call last):
  File "test_pytorch.py", line 5, in <module>
    t = torch.Tensor(a)
TypeError: can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, int64, int32, int16, int8, uint8, and bool.
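Since tensors only hold numeric and boolean types, a common workaround is to map each string to an integer id before conversion, so that only the ids become a tensor. A stdlib-only sketch, assuming a hypothetical vocabulary built from training data, with id 0 reserved for unseen strings (a design choice here, not a PyTorch convention):

```python
from typing import Dict, Iterable, List


def build_vocab(train_values: Iterable[str]) -> Dict[str, int]:
    # Reserve id 0 for out-of-vocabulary strings; known strings start at 1.
    return {v: i + 1 for i, v in enumerate(sorted(set(train_values)))}


def lookup(vocab: Dict[str, int], values: Iterable[str]) -> List[int]:
    # Unknown strings map to 0 instead of raising an error at serving time.
    return [vocab.get(v, 0) for v in values]


vocab = build_vocab(['apples', 'foobar', 'cowboy'])
ids = lookup(vocab, ['cowboy', 'apples', 'unknown'])
print(ids)  # [2, 1, 0]
# These integers can be converted, e.g. torch.tensor(ids, dtype=torch.int64).
```

The string lookup itself still runs outside the model, which is exactly the part that would have to be shipped alongside a TorchScript or ONNX file.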
workingloong commented 4 years ago

The key challenge is how to save the preprocessing logic into the serialized model (TorchScript or ONNX) for serving.

Proposal Options:

  • Develop some custom PyTorch OPs to do the preprocessing. The OPs can be saved into TorchScript or ONNX. We need to cover the transform functions in the list.
  • Develop a feature engineering library (not based on PyTorch OPs). The UI for this library is configuration driven or Python. For model training/evaluation, we will preprocess the source table, write the result into a temp table, and then run the training loop on the temp table. For model serving, we will use the library configuration and TorchScript/ONNX together.
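The second option can be sketched as a configuration file written at training time and loaded by the serving process, which applies it before calling the model. Everything here is hypothetical: the JSON schema, the transform type names, and the column names `city` and `age` are illustrative assumptions, not an existing SQLFlow or PyTorch format.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


def transform_row(config: Dict[str, Any], row: Dict[str, Any]) -> List[float]:
    """Apply the configured transforms to one input row (one data instance)."""
    features: List[float] = []
    # Feature order follows the key order in the config.
    for column, spec in config.items():
        value = row.get(column)
        if spec["type"] == "categorify":
            # Unknown categories map to id 0 (hypothetical convention).
            features.append(float(spec["vocab"].get(value, 0)))
        elif spec["type"] == "fill_normalize":
            if value is None:
                value = spec["fill"]
            features.append((value - spec["mean"]) / spec["std"])
        else:
            raise ValueError(f"unknown transform type: {spec['type']}")
    return features


# Statistics as they might come out of the training-time analysis step.
config = {
    "city": {"type": "categorify", "vocab": {"beijing": 1, "shanghai": 2}},
    "age": {"type": "fill_normalize", "fill": 30.0, "mean": 32.0, "std": 4.0},
}

# Training side: write the config next to the TorchScript/ONNX model file.
path = os.path.join(tempfile.mkdtemp(), "feature_config.json")
with open(path, "w") as f:
    json.dump(config, f)

# Serving side: load the config and preprocess each request before the model.
with open(path) as f:
    loaded = json.load(f)
print(transform_row(loaded, {"city": "shanghai", "age": None}))  # [2.0, -0.5]
```

In this sketch the model file and the config file must be versioned and deployed together, which is the coupling the next comment asks about.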

With the 2nd solution, the feature configuration is separated from the torch model. How can users combine them for serving?

brightcoder01 commented 4 years ago

Additional Options:

typhoonzero commented 4 years ago

Can we write libtorch dataset transform functions to achieve this?

brightcoder01 commented 4 years ago

Can we write libtorch dataset transform functions to achieve this?

I'm afraid that the transform logic in Dataset cannot be saved to model and used in serving.