apache/beam

Apache Beam is a unified programming model for Batch and Streaming data processing.
https://beam.apache.org/
Apache License 2.0

[Feature Request]: Allow specification of a custom model inference method for a RunInference ModelHandler #22572

Closed: agvdndor closed this issue 1 year ago

agvdndor commented 2 years ago

What would you like to happen?

The current implementation of RunInference provides model handlers for PyTorch and Sklearn models. These handlers assume that the method to call for inference is fixed: the sklearn handlers always call predict(), and the PyTorch handlers invoke the model as a callable (i.e. forward()).
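A simplified sketch of that fixed behavior (toy models, not the actual Beam handler code):

    import numpy as np
    import torch
    from sklearn.linear_model import LogisticRegression

    # The sklearn handlers always call predict() on the loaded model:
    sk_model = LogisticRegression().fit(np.array([[0.0], [1.0]]), np.array([0, 1]))
    sk_predictions = sk_model.predict(np.array([[0.5]]))

    # The PyTorch handlers always call the model as a callable, i.e. forward():
    torch_model = torch.nn.Linear(1, 1)
    torch_predictions = torch_model(torch.tensor([[0.5]]))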

However in some cases we want to provide a custom method for RunInference to call. Two examples:

  1. A number of pretrained models loaded with the Hugging Face transformers library recommend using the generate() method. From the Hugging Face docs on the T5 model:

    At inference time, it is recommended to use generate(). This method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder and auto-regressively generates the decoder output.

    
      from transformers import T5Tokenizer, T5ForConditionalGeneration

      tokenizer = T5Tokenizer.from_pretrained("t5-small")
      model = T5ForConditionalGeneration.from_pretrained("t5-small")

      input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
      outputs = model.generate(input_ids)
      print(tokenizer.decode(outputs[0], skip_special_tokens=True))
      # Das Haus ist wunderbar.
  2. Using OpenAI's CLIP model, which is implemented as a torch model, we might not want to execute the normal forward pass that encodes both images and text (image_embedding, text_embedding = clip_model(image, text)) but instead compute only the image embeddings (image_embedding = clip_model.encode_image(image)).

Solution: Allowing the user to specify the inference_fn when creating a ModelHandler would enable this usage.

Issue Priority

Priority: 2

Issue Component

Component: sdk-py-core

yeandy commented 2 years ago

This also applies to scikit-learn. For example, RandomForestClassifier has predict(X), predict_proba(X), and predict_log_proba(X), as well as less common functions like apply(X).
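For illustration, a quick sketch of those alternatives on a fitted classifier (toy data, nothing Beam-specific):

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    X = np.random.rand(20, 4)
    y = np.array([0, 1] * 10)
    clf = RandomForestClassifier(n_estimators=10).fit(X, y)

    clf.predict(X)            # class labels
    clf.predict_proba(X)      # class probabilities
    clf.predict_log_proba(X)  # log of the class probabilities
    clf.apply(X)              # index of the leaf each sample lands in, per tree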

yeandy commented 2 years ago

Parent issue: https://github.com/apache/beam/issues/22117

TheNeuralBit commented 2 years ago

I think this would be difficult to do in a general (cross-ModelHandler) way as each ModelHandler is responsible for invoking its model, and they currently have different ways of doing so.

sklearn calls a predict method: https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L124

pytorch calls the model like a callable (which then uses the forward method IIUC?): https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L235

I think the best we could do to solve the problem generally is establish some kind of convention.

It's also worth noting that the generate method comes from Hugging Face's GenerationMixin and is not part of the torch.nn.Module API, which is what our contract covers: https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L199
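For example (a small check, assuming the transformers T5 model from the feature request above):

    import torch
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    print(isinstance(model, torch.nn.Module))    # True: satisfies the handler contract
    print(hasattr(torch.nn.Module, "generate"))  # False: not part of the torch.nn.Module API
    print(hasattr(model, "generate"))            # True: mixed in by transformers' GenerationMixin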

Is a separate generation ModelHandler a better solution?

agvdndor commented 2 years ago

I could imagine three options:

  1. Stick to the current contract and assume that users will subclass the existing handlers to accommodate their model when it falls outside of the contract.
  2. Create a separate GenerationModelHandler. I'm not a fan of this approach. As @yeandy commented, there are a lot of fairly common options out there: predict_proba, apply, encode, decode, generate... So this might not scale too well and could lead to a proliferation of model handlers.
  3. Let the user pass the model_inference_fn during initialization as an optional kwarg.

Personally, I'd prefer option three. Something like this:

from transformers import DistilBertForSequenceClassification, DistilBertConfig
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

model_handler = PytorchModelHandlerTensor(
    state_dict_path="<path-to-state-dict-file>",
    model_class=DistilBertForSequenceClassification,
    model_params={"config": DistilBertConfig.from_pretrained("<path-to-config-file>")},
    model_inference_fn=DistilBertForSequenceClassification.generate)
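If something like that kwarg existed, the handler would plug into a pipeline the usual way (a sketch; model_inference_fn is the proposed, not yet existing, parameter, and examples stands in for the actual input data):

    import apache_beam as beam
    from apache_beam.ml.inference.base import RunInference

    # Hypothetical usage of the handler sketched above; RunInference itself is unchanged.
    with beam.Pipeline() as pipeline:
        _ = (
            pipeline
            | "CreateExamples" >> beam.Create(examples)  # examples: tokenized input tensors (assumed)
            | "Inference" >> RunInference(model_handler))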

Wdyt?

yeandy commented 2 years ago

pytorch calls the model like a callable (which then uses the forward method IIUC?):

Correct.


And thanks @agvdndor for the detailed suggestions!

There are some other workarounds that users could do. Would these be sufficient solutions to this?

  1. Create a wrapper class that inherits from torch.nn.Module and override its forward() method so that it calls the model's intended inference function. (Note: this code is just an example and isn't necessarily the best or correct way to do this.)

    import torch

    class Tacotron2Wrapper(torch.nn.Module):
        def __init__(self, model):
            # model: a loaded Tacotron2 instance whose inference entry point is infer()
            super().__init__()
            self._model = model

        def forward(self, inputs, input_lengths):
            mel, _, _ = self._model.infer(inputs, input_lengths)
            return mel
  2. Inherit ModelHandler, and change the run_inference function to call model.infer() instead of model(). This might be easier than the first solution, but does require the user to copy the other logic correctly.

    def run_inference(
      self,
      batch: Sequence[torch.Tensor],
      model: torch.nn.Module,
      inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
      inference_args = {} if not inference_args else inference_args

      batched_tensors = torch.stack(batch)
      batched_tensors = _convert_to_device(batched_tensors, self._device)
      predictions = model.infer(batched_tensors, **inference_args)
      return [PredictionResult(x, y) for x, y in zip(batch, predictions)]

damccorm commented 2 years ago

Generally, my take is that we should go with option 3 here and allow users to pass in a custom function. Basically:

1) This is a reasonably common pattern.
2) Adding support shouldn't be too hard.
3) Asking users to create their own handler (or model wrapper) any time they want to use a different method doesn't scale well. Some might contribute it back to the community, most won't, and even with those who do we're incurring an extra review/maintenance burden. It also significantly raises the bar for first-time users, who would now need to understand the handler internals.
4) Supporting this doesn't meaningfully make it harder for users in the simple use case (omitting this param should do nothing).

@jrmccluskey could you pick this one up when you have the bandwidth?

jrmccluskey commented 2 years ago

Looking into this a little bit, it's doable for each handler type but the end result is somewhat restrictive for the user. The provided function is going to have to take the same arguments in the same position as the current inference methods. For the given examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional generate() parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who will want more tuning and will likely turn to bespoke options.

It also looks like providing the alternate inference function will need to be done at run_inference call-time, not handler init-time, since the scikit-learn and PyTorch approaches are using functions from specific instances of their respective models. Can't specify the function until you have the model, unless I'm missing something.

damccorm commented 2 years ago

The provided function is going to have to take the same arguments in the same position as the current inference methods. For the given examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional generate() parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who will want more tuning and will likely turn to bespoke options.

I'm not 100% sure this is true; for example, I could imagine an approach where we let users pass in some sort of function like lambda model, batched_tensors, inference_args: model.generate(...). Regardless, I think the optional inference_args probably give users enough flexibility here, though it would be good to validate that against an existing model example.

It also looks like providing the alternate inference function will need to be done at run_inference call-time, not handler init-time, since the scikit-learn and PyTorch approaches are using functions from specific instances of their respective models. Can't specify the function until you have the model, unless I'm missing something.

You could probably do something with getattr where you pass in the function name as a string, though I don't love that approach since it's not very flexible w/ parameters. You could also again let them pass in a function. It's a little more work for a user, but might be worth the customizability (and for users that don't need it, their function would just be lambda model, batched_tensors, **inference_args: model.doSomething(batched_tensors, **inference_args)).
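Concretely, the two ideas could look roughly like this (illustrative only; ToyModel and the callable signature are made up for the sketch, not an agreed-on Beam API):

    import torch

    class ToyModel(torch.nn.Module):
        # Stand-in for a model whose inference entry point is not forward().
        def generate(self, batch, **kwargs):
            return batch * 2

    model = ToyModel()
    batched_tensors = torch.tensor([[1.0], [2.0]])
    inference_args = {}

    # 1) Resolve the method by name (a string) once the model is loaded:
    inference_fn = getattr(model, "generate")
    predictions = inference_fn(batched_tensors, **inference_args)

    # 2) Or let the user pass a callable that receives the loaded model:
    custom_fn = lambda model, batch, **kwargs: model.generate(batch, **kwargs)
    predictions = custom_fn(model, batched_tensors, **inference_args)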

Thoughts?

jrmccluskey commented 2 years ago

I've put together a brief doc discussing my perspective and preferred solution for this here - https://docs.google.com/document/d/1YYGsF20kminz7j9ifFdCD5WQwVl8aTeCo0cgPjbdFNU/edit?usp=sharing

PTAL

damccorm commented 1 year ago

@jrmccluskey could you please file a follow-up issue to update our notebooks to use this feature once this is released?

jrmccluskey commented 1 year ago

Filed as #24334

damccorm commented 1 year ago

Thanks!