bdnorman opened this issue 2 years ago.
The error is correct. Lists of tensors (`Tensor[]`) are not currently supported as inputs, and that model's documentation confirms that its input is a list of tensors.
If you'd like, I can file a feature request.
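For anyone who wants to check a model before deploying it, one rough option is to inspect the types of its TorchScript graph inputs and look for a `Tensor[]` argument. This is a sketch, not Triton's own validation logic: `ListInputModel`, `TensorInputModel`, and `has_tensor_list_input` are hypothetical names, and `torch._C.ListType` / `torch._C.TensorType` are internal PyTorch classes that may change between releases.

```python
from typing import List

import torch


class ListInputModel(torch.nn.Module):
    """Hypothetical model taking List[Tensor], like torchvision's detection models."""

    def forward(self, images: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack([img.mean() for img in images])


class TensorInputModel(torch.nn.Module):
    """Hypothetical model taking a single Tensor, like torchvision.models.resnet18."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


def has_tensor_list_input(scripted: torch.jit.ScriptModule) -> bool:
    """Return True if any forward() argument is typed as a list of tensors (Tensor[])."""
    # The first graph input is `self`; the remaining ones are forward()'s arguments.
    for value in list(scripted.graph.inputs())[1:]:
        t = value.type()
        if isinstance(t, torch._C.ListType) and isinstance(t.getElementType(), torch._C.TensorType):
            return True
    return False


print(has_tensor_list_input(torch.jit.script(ListInputModel())))    # True
print(has_tensor_list_input(torch.jit.script(TensorInputModel())))  # False
```

The error message also mentions a single `Dict(str, Tensor)` input as an accepted form; this check does not look for that case.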
@dyastremsky, a new feature would be great!
Also, for anyone who comes across this before the feature exists, here is a hack that seems to work around the list-of-tensors input:
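```python
import torch


class WrappedFasterRCNNModel(torch.nn.Module):
    """Exposes a single Tensor input so Triton's PyTorch backend can load the model."""

    def __init__(self, fasterrcnn_model):
        super().__init__()
        self.rcnn = fasterrcnn_model

    def forward(self, x):
        # x is one image of shape [C, H, W]; the detection model expects a list of images.
        assert isinstance(x, torch.Tensor)
        out = self.rcnn([x])
        out = out[0]  # detections for the single input image
        return out['boxes'], out['labels'], out['scores']
```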
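A sketch of how the wrapper might be exported and sanity-checked before it goes into the model repository. It assumes export with `torch.jit.trace`: when torchvision's detection models run under `torch.jit.script` they return a `(losses, detections)` tuple, so `out[0]` would pick up the losses dict instead of the detections. To stay self-contained it uses a hypothetical `ToyDetector` in place of Faster R-CNN.

```python
import os
import tempfile
from typing import Dict, List

import torch


class ToyDetector(torch.nn.Module):
    """Hypothetical stand-in with the I/O contract of a torchvision detector in eval
    mode: List[Tensor] of [C, H, W] images in, List[Dict[str, Tensor]] out."""

    def forward(self, images: List[torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
        results = []
        for img in images:
            # One fake detection per image; the score is just the image mean.
            results.append({
                "boxes": torch.zeros(1, 4),
                "labels": torch.ones(1, dtype=torch.int64),
                "scores": img.mean().reshape(1),
            })
        return results


class WrappedFasterRCNNModel(torch.nn.Module):
    """The wrapper from this thread: one [C, H, W] Tensor in, three Tensors out."""

    def __init__(self, fasterrcnn_model):
        super().__init__()
        self.rcnn = fasterrcnn_model

    def forward(self, x):
        assert isinstance(x, torch.Tensor)
        out = self.rcnn([x])[0]  # detections for the single input image
        return out["boxes"], out["labels"], out["scores"]


wrapped = WrappedFasterRCNNModel(ToyDetector().eval())
# Trace (rather than script) so that the eager-mode outputs are what gets recorded.
traced = torch.jit.trace(wrapped, torch.rand(3, 32, 48))

check = torch.rand(3, 32, 48)
with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "model.pt")  # Triton's default PyTorch model filename
    traced.save(path)
    boxes, labels, scores = torch.jit.load(path)(check)
```

With the real model, `ToyDetector().eval()` would be replaced by the torchvision Faster R-CNN model in eval mode and the example input by a representative `[3, H, W]` image; whether the traced graph behaves correctly for other image sizes is worth checking on your own data. Because the wrapper takes one unbatched image, the config.pbtxt would presumably need `max_batch_size: 0` so that Triton does not prepend a batch dimension.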
Happy you found a workaround, thanks for sharing it! I've filed a ticket to look into adding this feature.
Thank you for your contribution; I managed to solve it this way. In any case, it would be great if Triton allowed a list of tensors. Sometimes I want to return a list of tensors from the forward method. I can convert the outputs back into a list of tensors after inference, but it is quite awkward.
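Until list inputs and outputs are supported, one common pattern for the variable-length case is to concatenate the tensors along the first dimension, return their lengths as a second tensor, and split them again after inference. A minimal sketch, assuming every tensor shares the same trailing shape; `pack_tensor_list` and `unpack_tensor_list` are hypothetical helpers, not Triton or PyTorch APIs.

```python
from typing import List, Tuple

import torch


def pack_tensor_list(tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate tensors along dim 0 and record each one's length.

    Assumes all tensors share the same trailing shape (e.g. [N_i, 4] boxes).
    """
    lengths = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64)
    return torch.cat(tensors, dim=0), lengths


def unpack_tensor_list(flat: torch.Tensor, lengths: torch.Tensor) -> List[torch.Tensor]:
    """Inverse of pack_tensor_list: split the flat tensor back into the original pieces."""
    return list(torch.split(flat, lengths.tolist(), dim=0))


boxes_per_image = [torch.rand(2, 4), torch.rand(0, 4), torch.rand(3, 4)]
flat, lengths = pack_tensor_list(boxes_per_image)
print(flat.shape, lengths.tolist())  # torch.Size([5, 4]) [2, 0, 3]
restored = unpack_tensor_list(flat, lengths)
```

On a Triton client the outputs arrive as NumPy arrays, where `np.split(flat, np.cumsum(lengths)[:-1])` does the same job as `unpack_tensor_list`.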
**Description**
When trying to load my PyTorch model in Triton, I receive the following error:

```
Internal: An input of type 'Tensor[]' was detected in the model. Only a single input of type Dict(str, Tensor) or input(s) of type Tensor are supported.
```
**Triton Information**
nvcr.io/nvidia/tritonserver:22.04-py3

Are you using the Triton container or did you build it yourself? Using the Triton container.
**To Reproduce**
I am using the torchvision Faster R-CNN model and converting it to TorchScript.

My config.pbtxt file is:
**Expected behavior**
The Triton server should start and load the model. I tested with `torchvision.models.resnet18`, and it works fine.