pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Modularize `base_handler.py` into `handler.utils` #1631

Closed msaroufim closed 1 year ago

msaroufim commented 2 years ago

Modularize base_handler.py into handler.utils

This is a follow-up to #1440, where we first noted a few problems that were making it difficult to add code to TorchServe.

Initially, the base_handler had a few simple responsibilities:

  1. Loading a model
  2. Calling an inference
  3. Registering some simple metrics

But over time, contributors found it easy to extend TorchServe's functionality by extending the handler, because doing so only involved Python changes. We've been reluctant to fully embrace this approach, since integrating more natively with WorkloadManager.java lets us gate certain functionality behind config.properties, similar to https://github.com/pytorch/serve/pull/1319. However, for more complex features that code review process becomes much more challenging and grinds to a halt, as in https://github.com/pytorch/serve/pull/1401.

At the same time, examples/ has grown in size and complexity, for instance https://github.com/pytorch/serve/tree/master/examples/Huggingface_Transformers, without its improvements being easily usable by other parts of the code base.

So how can we encourage users to contribute their ideas to the easy-to-extend handlers without creating a giant monolithic file? How do we encourage modularity and code reuse?

Describe the solution

Instead of designing things in the abstract, it's important to look at what contributors have actually added to handlers. Here's a non-exhaustive list:

  1. Environment variables: Added support for environment variable checking to gate certain behaviors https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L25
  2. Profiling: Added support for pytorch profiler https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L218
  3. Model interpretability: via Captum https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L287
  4. Model optimization: via IPEX https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L96
  5. Error handling on batch processing
  6. Utilities around pulling HuggingFace pretrained models: https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py#L72
  7. Model parallelism: https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py#L84
  8. Optimized preprocessing: via libraries like DALI or accimage https://github.com/pytorch/serve/pull/1545

Our proposal is simple: instead of encouraging one-off contributions in examples/ or complex contributions to the Java frontend, we can create an interface for the existing handler_utils/, refactor base_handler.py and Transformer_handler_generalized.py to use it, and, through CONTRIBUTING.md and code review, encourage new contributors to create their own utilities.

Let's walk through a few examples to make this clearer.

Environment variables now

# in base_handler.py
import logging
import os

logger = logging.getLogger(__name__)

ipex_enabled = False
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
    try:
        import intel_extension_for_pytorch as ipex
        ipex_enabled = True
    except ImportError:
        logger.warning("IPEX is enabled but intel-extension-for-pytorch is not installed. Proceeding without IPEX.")

Environment variables after

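# in torch_handler/utils/env.py
import importlib
import logging
import os

logger = logging.getLogger(__name__)

# Map each gating environment variable to the package it enables
env_to_package = {"TS_IPEX_ENABLE": "intel_extension_for_pytorch",
                  "CAPTUM": "captum"}

enabled_packages = {}
for key, package in env_to_package.items():
    if os.environ.get(key, "false") == "true":
        try:
            enabled_packages[key] = importlib.import_module(package)
        except ImportError:
            logger.warning(f"{key} is enabled but {package} is not available. Proceeding without {package}")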

The difference is that we can now add as many environment variables as we like instead of accumulating if conditions inside the handler. Profiling code could then be pulled out of base_handler.py, and new profilers could be added with the only constraint being that they write something to disk.
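As a minimal sketch of that constraint, assume a hypothetical `Profiler` interface (not an existing TorchServe API) whose only contract is to run a function and write a report to a file. The standard library's `cProfile` and `pstats` serve as a stand-in implementation here; a wrapper around the PyTorch profiler would implement the same `profile` method.

```python
import cProfile
import io
import os
import pstats
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Callable


class Profiler(ABC):
    """Hypothetical interface: the only contract is writing results to disk."""

    def __init__(self, output_path: str):
        self.output_path = output_path

    @abstractmethod
    def profile(self, fn: Callable[..., Any], *args, **kwargs) -> Any:
        """Run fn(*args, **kwargs), write a report to self.output_path, return fn's result."""


class CProfileProfiler(Profiler):
    """Stand-in profiler built on the standard library's cProfile."""

    def profile(self, fn, *args, **kwargs):
        profiler = cProfile.Profile()
        result = profiler.runcall(fn, *args, **kwargs)
        # Render the top entries by cumulative time as text and persist them
        buffer = io.StringIO()
        pstats.Stats(profiler, stream=buffer).sort_stats("cumulative").print_stats(10)
        with open(self.output_path, "w") as f:
            f.write(buffer.getvalue())
        return result


# Usage: profile a toy inference function
def infer(batch):
    return [x * 2 for x in batch]


path = os.path.join(tempfile.gettempdir(), "ts_profile.txt")
out = CProfileProfiler(path).profile(infer, [1, 2, 3])
```

Because callers only depend on `profile()` returning the function's result and leaving a file behind, the handler would not need to know which profiler is configured.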

Model optimization now

# in base_handler.py (inside initialize(), once self.model has been loaded)
if ipex_enabled:
    self.model = self.model.to(memory_format=torch.channels_last)
    self.model = ipex.optimize(self.model)

Model optimization after

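# in torch_handler/utils/optimization.py
import torch


class Optimization:
    """Interface for runtime providers: given an nn.Module, produce another nn.Module."""

    def optimize(self, model: torch.nn.Module, **kwargs) -> torch.nn.Module:
        raise NotImplementedError("This is an abstract base class, you need to call or create your own runtime")


class IPEXOptimization(Optimization):
    def optimize(self, model: torch.nn.Module, memory_format=torch.channels_last, **kwargs) -> torch.nn.Module:
        # Imported lazily so IPEX is only required when this optimization is used
        import intel_extension_for_pytorch as ipex

        model = model.to(memory_format=memory_format)
        return ipex.optimize(model, **kwargs)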

The difference now is that we create a simpler interface for runtime providers and make it easy to add new kinds of optimizations. For example, pruning, distillation, and quantization would all follow a similar interface: given an nn.Module, produce another nn.Module. experimental/torchprep follows a similar design philosophy, which makes optimizations more composable.
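To illustrate the composability point, here is a minimal sketch built around a hypothetical registry (`OPTIMIZATIONS`, `register` and `apply_optimizations` are illustrative names, not TorchServe or torchprep APIs). Every entry is a model -> model function, so steps can be chained in whatever order a config lists them. Plain dicts stand in for `torch.nn.Module` so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict, Iterable

# Hypothetical registry: optimization name -> function mapping a model to a new model
OPTIMIZATIONS: Dict[str, Callable[[Any], Any]] = {}


def register(name: str):
    """Decorator that adds a model -> model function to the registry."""
    def decorator(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
        OPTIMIZATIONS[name] = fn
        return fn
    return decorator


def apply_optimizations(model: Any, names: Iterable[str]) -> Any:
    """Apply the named optimizations in order, each consuming the previous output."""
    for name in names:
        if name not in OPTIMIZATIONS:
            raise KeyError(f"Unknown optimization: {name}")
        model = OPTIMIZATIONS[name](model)
    return model


# Stand-ins for real passes such as pruning or quantization
@register("prune")
def prune(model):
    return {**model, "pruned": True}


@register("quantize")
def quantize(model):
    return {**model, "dtype": "int8"}


optimized = apply_optimizations({"name": "resnet18"}, ["prune", "quantize"])
```

With PyTorch installed, a real entry could wrap `torch.ao.quantization.quantize_dynamic`, which already takes a module and returns a new one.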

The other application-specific code that contributors have added to base_handler.py can be handled the same way: instead of favoring one-offs, we encourage code reuse through tiny wrapper classes. Things get particularly exciting once we consider that TorchServe can execute an arbitrary Python script, so we could, for example, also provide a pybind wrapper for torch::deploy and create a simple wrapper in torch_handler.utils.launch to serve PyTorch in a multithreaded manner without touching any Java code.
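As an example of such a tiny wrapper, here is a sketch for item 5 in the list above (error handling on batch processing). `BatchErrorIsolation` is a hypothetical name, and the sketch assumes the batched inference function also accepts a batch of one; if the whole batch fails, it retries request by request so a single bad input does not fail everyone else's request.

```python
import logging
from typing import Any, Callable, List

logger = logging.getLogger(__name__)


class BatchErrorIsolation:
    """Hypothetical wrapper: run batched inference, falling back to per-item
    calls when the batch fails so one bad request doesn't sink the rest."""

    def __init__(self, infer_batch: Callable[[List[Any]], List[Any]]):
        self.infer_batch = infer_batch

    def __call__(self, batch: List[Any]) -> List[Any]:
        try:
            return self.infer_batch(batch)
        except Exception:
            logger.warning("Batch inference failed; retrying %d requests individually", len(batch))
        results = []
        for item in batch:
            try:
                results.append(self.infer_batch([item])[0])
            except Exception as error:
                # Put the error in this request's slot instead of failing the batch
                results.append({"error": str(error)})
        return results


# Usage: a toy "model" that fails on zero inputs
def reciprocal(batch):
    return [1 / x for x in batch]


safe_reciprocal = BatchErrorIsolation(reciprocal)
outputs = safe_reciprocal([1, 0, 4])
```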

Our contributors like and understand our handlers, it's time to make them a first class citizen.

EDIT: I discussed this proposal with @min-jean-cho offline. For IPEX, a convenient workflow has been to:

  1. Create a new set of configs in ConfigManager.java. For now it seems OK if this file is massive; we can split up configs with comments and link to ConfigManager.java in our docs so people can learn about the available features.
  2. Individual properties in configs are not just data, like an http_address string; they also need to support the use case of
if property:
  custom_code_snippet

The custom_code_snippet can live anywhere we like.

Follow-up on July 6

We've discussed this item further, and environment variables, while very convenient, have a few problems.

In what follows we discuss the main one: environment variables cannot be applied to a single model. I'm reopening this discussion because customers have privately asked for more native ONNX and TensorRT support, and more recently for MXNet support: https://github.com/pytorch/serve/issues/1725

There is one alternative, pointed out by @lxning, which would involve the backend and frontend communicating in-process.

Sketch of alternative solution

We should still use the ConfigManager for general TorchServe configuration, but use MODEL_CONFIG for model-specific configuration. In particular, most model optimizations are indeed model specific, whereas something like profiling could remain a global environment variable. So let's focus on model optimizations in what follows.

In this case we have a flag, ipex_enable, that gates whether IPEX is on or off, plus some IPEX-specific parameters. Here we should favor a design with hierarchical configurations, since it doesn't make sense to specify ipex_dtype if ipex_enable is not set.

// IPEX config option that can be set at config.properties
private static final String TS_IPEX_ENABLE = "ipex_enable";
private static final String TS_IPEX_DTYPE = "ipex_dtype";
private static final String TS_IPEX_CHANNEL_LAST = "ipex_channel_last";
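A minimal sketch of the hierarchical design on the Python side, assuming the flat `ipex_*` keys above reach the handler as strings. The `group_by_provider` helper and its grouping rule (split each key at the first underscore, keep a provider only when `<provider>_enable` is `"true"`) are hypothetical.

```python
from typing import Dict


def group_by_provider(flat: Dict[str, str]) -> Dict[str, Dict[str, str]]:
    """Hypothetical helper: nest "<provider>_<param>" keys under their provider.

    A provider's parameters are kept only when "<provider>_enable" is "true",
    so e.g. ipex_dtype is ignored unless ipex_enable is set.
    """
    grouped: Dict[str, Dict[str, str]] = {}
    for key, value in flat.items():
        provider, _, param = key.partition("_")
        grouped.setdefault(provider, {})[param] = value
    return {
        provider: {p: v for p, v in params.items() if p != "enable"}
        for provider, params in grouped.items()
        if params.get("enable", "false").lower() == "true"
    }


config = group_by_provider({
    "ipex_enable": "true",
    "ipex_dtype": "bfloat16",
    "ipex_channel_last": "true",
    "ort_dtype": "float16",  # dropped: ort_enable is not set
})
```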

Regardless, once configurations are set in ConfigManager this way, we can choose to expose them as environment variables, as we already do for LOG_LOCATION:

String logLocation = System.getenv("LOG_LOCATION");
if (logLocation != null) {
    System.setProperty("LOG_LOCATION", logLocation);
} else if (System.getProperty("LOG_LOCATION") == null) {
    System.setProperty("LOG_LOCATION", "logs");
}

Model config doesn't have a schema, so any value in the JSON can be used by someone downstream, since it gets parsed into a hashmap:

private Map<String, Map<String, JsonObject>> modelConfig = new HashMap<>();

This makes it ideal, because no one needs to add any code to support a new parameter in the model config. However, the current modelConfig only accepts integer values (which is fine, since we can treat any nonzero value as a boolean true).

So maybe, instead of environment variables, we make modelConfig available as a dictionary to any handler through the context object (https://github.com/pytorch/serve/blob/master/ts/context.py#L8). Then we can support new model optimizations by querying modelConfig = {...}.
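Here is a sketch of what such a query might look like in a handler. It assumes a hypothetical `model_config` attribute on the context (the proposal above, not the current `ts.context.Context` API), uses a small stand-in `Context` dataclass so it runs without TorchServe, and treats any nonzero integer as true, as suggested for the integer-only modelConfig.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Context:
    """Stand-in for ts.context.Context with the proposed (hypothetical) model_config field."""
    model_name: str
    model_config: Dict[str, Any] = field(default_factory=dict)


def is_enabled(context: Context, key: str) -> bool:
    # modelConfig currently stores integers, so any nonzero value counts as "on"
    return bool(context.model_config.get(key, 0))


def enabled_optimizations(context: Context, known: List[str]) -> List[str]:
    """Return the optimizations whose "<name>_enable" flag is set for this model."""
    return [name for name in known if is_enabled(context, f"{name}_enable")]


ctx = Context("resnet18", model_config={"ipex_enable": 1, "onnx_enable": 0})
selected = enabled_optimizations(ctx, ["ipex", "onnx", "tensorrt"])
```

Because the lookup is per context, two models served by the same TorchServe instance could enable different optimizations, which a process-wide environment variable cannot do.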

Those same parameters also need to be available via the model registration API, so model-optimization-related registration parameters would go here: https://github.com/pytorch/serve/blob/master/frontend/server/src/main/java/org/pytorch/serve/openapi/OpenApiUtils.java#L241

operation.addParameter(
        new QueryParameter(
                "batch_size", "integer", "1", "Inference batch size, default: 1."));

And here: https://github.com/pytorch/serve/blob/master/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java#L10

    @SerializedName("batch_size")
    private int batchSize;

In which case, can we make this request object available to the handler? https://github.com/pytorch/serve/blob/master/frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java#L74-L81

    public RegisterModelRequest() {
        batchSize = 1;
        maxBatchDelay = 100;
        synchronous = true;
        initialWorkers = ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel();
        responseTimeout = -1;
        s3SseKms = false;
    }

OTF message handler

After many discussions, we feel the best solution would be to extend our existing protobuf files so that they support model optimization runtimes and their settings: https://github.com/pytorch/serve/blob/master/frontend/server/src/main/resources/proto/management.proto

Because settings vary by runtime provider and are hard to keep backward compatible, we can use a protobuf map (https://developers.google.com/protocol-buffers/docs/proto3#maps) to keep track of each model optimization provider's settings.

So the addition to our .proto files would be:


message RegisterModelRequest {
    ...

    // Decides whether S3 SSE KMS enabled or not, default: false.
    bool s3_sse_kms = 10; //optional

    // Per-runtime optimization settings. Map fields cannot take a label such as
    // "optional"; a map that is not set is simply empty.
    map<string, string> ipex_settings = 11;
    map<string, string> ort_settings = 12;
    map<string, string> rt_settings = 13;
}

And then in the base handler we would unpack each of those maps and set up the correct configuration per the exact same design described in the first section of this doc.
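Because protobuf map values are strings, the handler would have to convert them before passing them to a runtime. A minimal sketch, assuming the request arrives as a Python dict mirroring the `*_settings` fields above; `parse_value` and `unpack_settings` are hypothetical helpers, and the conversion rules (booleans, then ints, then floats, otherwise the raw string) are an assumption.

```python
from typing import Any, Dict, Mapping


def parse_value(raw: str) -> Any:
    """Best-effort conversion of a string map value to bool, int, float, or str."""
    lowered = raw.strip().lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def unpack_settings(request: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Collect every non-empty "<provider>_settings" map into typed kwargs per provider."""
    return {
        key[: -len("_settings")]: {k: parse_value(v) for k, v in settings.items()}
        for key, settings in request.items()
        if key.endswith("_settings") and settings
    }


request = {
    "batch_size": 1,  # non-settings fields are ignored
    "ipex_settings": {"dtype": "bfloat16", "channel_last": "true"},
    "ort_settings": {"intra_op_num_threads": "4"},
    "rt_settings": {},  # empty map: provider not configured
}
per_provider = unpack_settings(request)
```

Each entry of `per_provider` could then be handed to the matching optimization class from the first section as keyword arguments.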

One interesting aspect of using an updated protobuf format is that we can change model optimizations at runtime instead of setting them statically, so people could, for example, change their ONNX configuration depending on load.
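One way such runtime changes could work, sketched under assumptions: keep the unoptimized model, and when new settings arrive, build the optimized copy off to the side and swap it in under a lock, so requests that already fetched the old model keep using it. `ReconfigurableModel` and the `build` callback are hypothetical; in a real handler `build` could apply the optimization classes from the first section.

```python
import threading
from typing import Any, Callable, Dict


class ReconfigurableModel:
    """Hypothetical holder that re-optimizes a base model when settings change.

    Assumes the caller serializes calls to update_settings.
    """

    def __init__(self, base_model: Any, build: Callable[[Any, Dict[str, Any]], Any],
                 settings: Dict[str, Any]):
        self._base_model = base_model
        self._build = build
        self._lock = threading.Lock()
        self._settings = dict(settings)
        self._model = build(base_model, self._settings)

    def update_settings(self, settings: Dict[str, Any]) -> None:
        # Build outside the lock so requests are not blocked during re-optimization
        new_model = self._build(self._base_model, dict(settings))
        with self._lock:
            self._settings = dict(settings)
            self._model = new_model

    def current(self) -> Any:
        with self._lock:
            return self._model


# Usage: a stand-in build step that records the settings it was given
def build(model, settings):
    return {"base": model, **settings}


holder = ReconfigurableModel("onnx-model", build, {"intra_op_num_threads": 2})
holder.update_settings({"intra_op_num_threads": 8})
```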

lxning commented 2 years ago

This code refactoring makes the backend handler elegant. The only issue I can see relates to the communication between frontend and backend via environment variables.