This also applies to scikit-learn. For example, `RandomForestClassifier` has `predict(X)`, `predict_proba(X)` or `predict_log_proba(X)`, and other less common functions like `apply(X)`, etc.
Parent issue: https://github.com/apache/beam/issues/22117
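For illustration, here is a minimal, self-contained sketch (toy data, not from the issue) of the different scikit-learn inference entry points mentioned above:

```python
from sklearn.ensemble import RandomForestClassifier

# Toy model purely to illustrate the different inference methods a user
# might want RunInference to call.
X, y = [[0], [1], [2], [3]], [0, 0, 1, 1]
model = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

print(model.predict([[1.5]]))            # predicted class labels
print(model.predict_proba([[1.5]]))      # class probabilities
print(model.predict_log_proba([[1.5]]))  # log of the class probabilities
print(model.apply([[1.5]]))              # leaf indices, one per tree
```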
I think this would be difficult to do in a general (cross-ModelHandler) way, as each ModelHandler is responsible for invoking its model, and they currently have different ways of doing so.
sklearn calls a predict method: https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L124
pytorch calls the model like a callable (which then uses the forward method IIUC?): https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L235
I think the best we could do to solve the problem generally is establish some kind of convention.
It's also worth noting that the `generate` method comes from Hugging Face's `GenerationMixin`, not from the `torch.nn.Module` API, which is what's in our contract: https://github.com/apache/beam/blob/5b1e1520b975de563b8b57144927894a2fddded1/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L199
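As a quick illustrative check (not from the thread) that `generate` really lives outside the `torch.nn.Module` API:

```python
import torch
from transformers import T5ForConditionalGeneration

# Hugging Face model classes subclass torch.nn.Module, but `generate` comes
# from the GenerationMixin they also inherit from, not from the Module API.
print(issubclass(T5ForConditionalGeneration, torch.nn.Module))  # True
print(hasattr(torch.nn.Module, "generate"))                     # False
print(hasattr(T5ForConditionalGeneration, "generate"))          # True
```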
Is a separate generation modelhandler a better solution?
I could imagine three options: `predict_proba`, `apply`, `encode`, `decode`, `generate`... So this might not scale too well and lead to a proliferation of model handlers. Personally, I'd prefer option three. Something like this:
```python
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

model_handler = PytorchModelHandlerTensor(
    state_dict_path="<path-to-state-dict-file>",
    model_class=DistilBertForSequenceClassification,
    model_params={"config": DistilBertConfig("<path-to-config-file>")},
    model_inference_fn=DistilBertForSequenceClassification.generate)
```
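For context, such a handler would then be consumed like any other handler via `RunInference` (a sketch; `examples` is a placeholder input, and the `model_inference_fn` argument above is the proposal, not an existing parameter):

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference

with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "CreateExamples" >> beam.Create(examples)  # placeholder input PCollection
        | "RunInference" >> RunInference(model_handler)
        | "Print" >> beam.Map(print))
```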
Wdyt?
> pytorch calls the model like a callable (which then uses the forward method IIUC?):
Correct.
And thanks @agvdndor for the detailed suggestions!

`GenerationModelHandler`: I agree that it does not scale well. `model_inference_fn` could work. The change itself shouldn't be that hard to implement. However, we need to ask ourselves -- at what point are we doing too much to address these custom use cases? On the one hand, I recognize that HuggingFace is very popular, and I'd be remiss to see a bunch of potential RunInference users turned away because of how difficult it is to plug a HuggingFace model into `PytorchModelHandlerTensor`. On the other hand, if we can capture 80% of use cases without this custom infer function, that might be good enough? If users do require a more tailored solution, then they probably should be writing their own `DoFn` anyway (inspired, of course, by our own implementation). @robertwb What are your thoughts on adding something like a GenerationModelHandler versus a `model_inference_fn`?

There are some other workarounds that users could do. Would these be sufficient solutions to this?
1. Create a wrapper class that inherits from `torch.nn.Module`, then override its `forward()` method to call the model's intended inference function. (Note: this code is just an example and isn't necessarily the best or correct way to do this.)
```python
import torch

class Tacotron2Wrapper(torch.nn.Module):
    def __init__(self, model=tacotron2):
        super().__init__()
        self._model = model

    def forward(self, inputs, input_lengths):
        mel, _, _ = self._model.infer(inputs, input_lengths)
        return mel
```
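Calling the wrapper then routes through `infer()`, so the handler's existing `model(...)` call works unchanged. A usage sketch (assumes `tacotron2`, `inputs`, and `input_lengths` already exist in the user's own loading/preprocessing code):

```python
# Illustrative only: `tacotron2`, `inputs`, and `input_lengths` are supplied
# by the user's model-loading and preprocessing code.
wrapped_model = Tacotron2Wrapper(tacotron2)
mel = wrapped_model(inputs, input_lengths)  # equivalent to tacotron2.infer(...)[0]
```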
2. Inherit the ModelHandler, and change its `run_inference` function to call `model.infer()` instead of `model()`. This might be easier than the first solution, but it does require the user to copy the other logic correctly.
```python
def run_inference(
    self,
    batch: Sequence[torch.Tensor],
    model: torch.nn.Module,
    inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionResult]:
    inference_args = {} if not inference_args else inference_args
    batched_tensors = torch.stack(batch)
    batched_tensors = _convert_to_device(batched_tensors, self._device)
    predictions = model.infer(batched_tensors, **inference_args)
    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
```
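Concretely, that override could live in a small subclass of the existing handler. A hypothetical sketch (the class name is invented here, and device handling is simplified to avoid Beam's private `_convert_to_device` helper):

```python
from typing import Any, Dict, Iterable, Optional, Sequence

import torch
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


class PytorchInferModelHandler(PytorchModelHandlerTensor):
    """Hypothetical handler identical to its parent, except that run_inference
    calls model.infer() instead of model()."""

    def run_inference(
        self,
        batch: Sequence[torch.Tensor],
        model: torch.nn.Module,
        inference_args: Optional[Dict[str, Any]] = None
    ) -> Iterable[PredictionResult]:
        inference_args = inference_args or {}
        batched_tensors = torch.stack(batch).to(self._device)
        predictions = model.infer(batched_tensors, **inference_args)
        return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
```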
Generally, my take is that we should do option 3 and allow users to pass in a custom function. Basically:

1) This is a reasonably common pattern.
2) Adding support shouldn't be too hard.
3) Asking users to create their own handler (or model wrapper) any time they want to use a different method doesn't scale well. Some might contribute it back to the community, most won't, and even with those who do we're incurring an extra review/maintenance burden. It also significantly raises the bar for first-time users, who would now need to understand the handler internals.
4) Supporting this doesn't meaningfully make it harder for users in the simple use case (omitting this param should do nothing).
@jrmccluskey could you pick this one up when you have the bandwidth?
Looking into this a little bit, it's doable for each handler type, but the end result is somewhat restrictive for the user. The provided function is going to have to take the same arguments in the same positions as the current inference methods. For the examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional `generate()` parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who want more tuning and will likely turn to bespoke options.

It also looks like providing the alternate inference function will need to be done at `run_inference` call time, not handler init time, since the scikit-learn and PyTorch approaches use functions from specific instances of their respective models. You can't specify the function until you have the model, unless I'm missing something.
> The provided function is going to have to take the same arguments in the same position as the current inference methods. For the given examples discussed this isn't a huge issue (unless HuggingFace users really want to use the 30+ optional generate() parameters) and will likely cover a large number of use cases, but we'll still have some advanced users who will want more tuning and will likely turn to bespoke options.
I'm not 100% sure this is true; for example, I could imagine an approach where we let users pass in some sort of function like `lambda model, batched_tensors, inference_args: model.generate(...)`. Regardless, I think the optional `inference_args` probably give users enough flexibility here, though it would be good to validate that against an existing model example.

> It also looks like providing the alternate inference function will need to be done at run_inference call-time, not handler init-time, since the scikit-learn and PyTorch approaches are using functions from specific instances of their respective models. Can't specify the function until you have the model, unless I'm missing something.

You could probably do something with `getattr` where you pass in the function name via string, though I don't love that approach since it's not very flexible w/ parameters. You could also again let them pass in a function. It's a little more work for a user, but might be worth the customizability (and for users that don't need it, their function would just be `lambda model, batched_tensors, **inference_args: model.doSomething(batched_tensors, **inference_args)`).
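To make the idea concrete, a hypothetical sketch of what such user-supplied functions could look like (names are illustrative, not Beam API):

```python
# What a user would pass for the default behaviour vs. a HuggingFace generate() call:
def default_inference_fn(model, batched_tensors, **inference_args):
    return model(batched_tensors, **inference_args)


def generate_inference_fn(model, batched_tensors, **inference_args):
    return model.generate(batched_tensors, **inference_args)

# Inside run_inference, the handler would then simply defer to whichever
# function was supplied at construction time, e.g.:
#   predictions = self._inference_fn(model, batched_tensors, **inference_args)
```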
Thoughts?
I've put together a brief doc discussing my perspective and preferred solution for this here - https://docs.google.com/document/d/1YYGsF20kminz7j9ifFdCD5WQwVl8aTeCo0cgPjbdFNU/edit?usp=sharing
PTAL
@jrmccluskey could you please file a follow up issue to update our notebooks to use this feature once this is released?
Filed as #24334
Thanks!
What would you like to happen?
The current implementation of RunInference provides model handlers for PyTorch and Sklearn models. These handlers assume that the method to call for inference is fixed:

- PyTorch: the `__call__` method -> `output = torch_model(input)`
- Sklearn: the `predict` method -> `output = sklearn_model.predict(input)`
However, in some cases we want to provide a custom method for RunInference to call. Two examples:

1. A number of pretrained models loaded with the Huggingface transformers library recommend using the `generate()` method (see the Huggingface docs on the T5 model).
2. Using OpenAI's CLIP model, which is implemented as a torch model, we might not want to execute the normal forward pass that encodes both images and text (`image_embedding, text_embedding = clip_model(image, text)`) but instead only compute the image embeddings: `image_embedding = clip_model.encode_image(image)`.

Solution: Allowing the user to specify the `inference_fn` when creating a ModelHandler would enable this usage.
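For example (a sketch of the proposed API only; `inference_fn` is not an existing argument of the handlers today):

```python
from sklearn.ensemble import RandomForestClassifier
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# Proposed usage: point the handler at predict_proba instead of predict.
model_handler = SklearnModelHandlerNumpy(
    model_uri="<path-to-pickled-model>",
    inference_fn=RandomForestClassifier.predict_proba)
```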
Issue Priority
Priority: 2
Issue Component
Component: sdk-py-core