bentoml / BentoML

The easiest way to serve AI/ML models in production - Build Model Inference Service, LLM APIs, Multi-model Inference Graph/Pipelines, LLM/RAG apps, and more!
https://bentoml.com
Apache License 2.0

bug: torch tensor converted to numpy array during serialization & non-writable warning #4839

Open · rlleshi opened this issue 4 days ago

rlleshi commented 4 days ago

Describe the bug

I'm passing a torch tensor to another BentoML service, but it appears to be converted into a numpy array during serialization first.

This conversion then triggers the following warning:

/home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/_bentoml_sdk/validators.py:276: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(obj)
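
For reference, the warning can be reproduced outside BentoML entirely. A minimal sketch, with the read-only flag set by hand to simulate a deserialized buffer (shape and dtype are arbitrary):

```python
import numpy as np
import torch

arr = np.zeros((3, 958, 640), dtype=np.uint8)
arr.flags.writeable = False  # simulate an array backed by a read-only buffer

# Emits the same UserWarning: from_numpy() shares memory with `arr`
# rather than copying, and PyTorch cannot mark the tensor read-only.
tensor = torch.from_numpy(arr)
```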

To reproduce

Minimal reproducible example:

service.py

from __future__ import annotations

import torch
import bentoml

@bentoml.service()
class TestService:

    def __init__(self) -> None:
        pass

    @bentoml.api
    def test(self, image: torch.Tensor) -> torch.Tensor:
        print(f'Type inside TestService.test: {type(image)}')
        return image

service2.py

from __future__ import annotations

import torch
import bentoml
import cv2
from .service import TestService

@bentoml.service()
class TestService2:

    test_service = bentoml.depends(TestService)

    def __init__(self) -> None:
        pass

    @bentoml.api
    def test2(self, image_path: str) -> str:
        img = cv2.imread(image_path)
        print('Writeable ', img.flags.writeable)
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)
        print(f'Type inside TestService2.test2 (before call): {type(img_tensor)}')
        result = self.test_service.test(img_tensor)
        print(f'Type inside TestService2.test2 (after call): {type(result)}')
        return 'OK'

bentofile.yaml

service: "src.service2:TestService2"
include:
  - "*.py"
python:
  requirements_txt: requirements.txt

requirements.txt

bentoml
opencv-python
torch

Here are the logs:

❯ bentoml serve .
2024-07-01T19:52:45+0200 [INFO] [cli] Starting production HTTP BentoServer from "src.service2:TestService2" listening on http://localhost:3000 (Press CTRL+C to quit)
Writeable  True
Type inside TestService2.test2 (before call): <class 'torch.Tensor'>
<class 'torch.Tensor'>, torch.Size([3, 958, 640]), File: /home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/pydantic/main.py, Line: 176
<class 'numpy.ndarray'>, (3, 958, 640), File: /home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/pydantic/main.py, Line: 551
Type inside TestService.test: <class 'torch.Tensor'>
<class 'torch.Tensor'>, torch.Size([3, 958, 640]), File: /home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/pydantic/rootmodel.py, Line: 71
2024-07-01T19:52:47+0200 [INFO] [service:TestService:1] (scheme=http,method=POST,path=/test,type=application/vnd.bentoml+pickle,length=1839537) (status=200,type=application/json,length=1839484) 13.497ms (trace=7ad7af6d592d6d9d5bad7ac796537c24,span=5920153c7d24fc1a,sampled=0,service.name=TestService)
<class 'numpy.ndarray'>, (3, 958, 640), File: /home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/pydantic/main.py, Line: 551
/home/user/.pyenv/versions/3.10.0/envs/tensor_np_conversion_bug/lib/python3.10/site-packages/_bentoml_sdk/validators.py:276: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(obj)
Type inside TestService2.test2 (after call): <class 'torch.Tensor'>
2024-07-01T19:52:47+0200 [INFO] [entry_service:TestService2:1] 127.0.0.1:41972 (scheme=http,method=POST,path=/test2,type=application/json,length=120) (status=200,type=text/plain; charset=utf-8,length=2) 72.050ms (trace=bf613730b60be2b0d0d8cd6da854c21b,span=543b6aa1b67d6475,sampled=0,service.name=TestService2)

Sample curl request:

curl --location --request POST 'http://localhost:3000/test2' \
--header 'Content-Type: application/json' \
--data-raw '{
    "image_path": "path_to_image"
}'

Expected behavior

I'm new to BentoML, so I may be missing something in the documentation, but I had thought tensors were supported as inputs and would arrive as tensors rather than being round-tripped through numpy.

Also, if they have to be converted to numpy first, shouldn't they be writable when converted back to a tensor?
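
For illustration, a copy at conversion time does seem to restore writability. A minimal sketch, with the read-only flag set by hand to stand in for whatever the serializer produces:

```python
import numpy as np
import torch

arr = np.zeros((3, 4), dtype=np.float32)
arr.flags.writeable = False        # stand-in for a deserialized array

t = torch.from_numpy(arr)          # warns: shares the read-only buffer
t2 = torch.from_numpy(arr.copy())  # no warning: the copy owns writable memory
t2[0, 0] = 1.0                     # safe to write
```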

Environment

Environment variable

BENTOML_DEBUG=''
BENTOML_QUIET=''
BENTOML_BUNDLE_LOCAL_BUILD=''
BENTOML_DO_NOT_TRACK=''
BENTOML_CONFIG=''
BENTOML_CONFIG_OPTIONS=''
BENTOML_PORT=''
BENTOML_HOST=''
BENTOML_API_WORKERS=''

System information

bentoml: 1.2.19 python: 3.10.0 platform: Linux-6.5.0-42-generic-x86_64-with-glibc2.38 uid_gid: 1000:1000

pip_packages
```
aiohttp==3.9.5 aiosignal==1.3.1 annotated-types==0.7.0 anyio==4.4.0 appdirs==1.4.4 asgiref==3.8.1 async-timeout==4.0.3 attrs==23.2.0 bentoml==1.2.19 build==1.2.1 cattrs==23.1.2 certifi==2024.6.2 circus==0.18.0 click==8.1.7 click-option-group==0.5.6 cloudpickle==3.0.0 deepmerge==1.1.1 Deprecated==1.2.14 exceptiongroup==1.2.1 filelock==3.15.4 frozenlist==1.4.1 fs==2.4.16 fsspec==2024.6.1 h11==0.14.0 httpcore==1.0.5 httpx==0.27.0 httpx-ws==0.6.0 idna==3.7 importlib-metadata==6.11.0 inflection==0.5.1 Jinja2==3.1.4 markdown-it-py==3.0.0 MarkupSafe==2.1.5 mdurl==0.1.2 mpmath==1.3.0 multidict==6.0.5 networkx==3.3 numpy==2.0.0 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-ml-py==11.525.150 nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.5.40 nvidia-nvtx-cu12==12.1.105 opencv-python==4.10.0.84 opentelemetry-api==1.20.0 opentelemetry-instrumentation==0.41b0 opentelemetry-instrumentation-aiohttp-client==0.41b0 opentelemetry-instrumentation-asgi==0.41b0 opentelemetry-sdk==1.20.0 opentelemetry-semantic-conventions==0.41b0 opentelemetry-util-http==0.41b0 packaging==24.1 pathspec==0.12.1 pip-requirements-parser==32.0.1 pip-tools==7.4.1 prometheus_client==0.20.0 psutil==6.0.0 pydantic==2.7.4 pydantic_core==2.18.4 Pygments==2.18.0 pyparsing==3.1.2 pyproject_hooks==1.1.0 python-dateutil==2.9.0.post0 python-json-logger==2.0.7 python-multipart==0.0.9 PyYAML==6.0.1 pyzmq==26.0.3 rich==13.7.1 schema==0.7.7 simple-di==0.1.5 six==1.16.0 sniffio==1.3.1 starlette==0.37.2 sympy==1.12.1 tomli==2.0.1 tomli_w==1.0.0 torch==2.3.1 tornado==6.4.1 triton==2.3.1 typing_extensions==4.12.2 uvicorn==0.30.1 watchfiles==0.22.0 wrapt==1.16.0 wsproto==1.2.0 yarl==1.9.4 zipp==3.19.2
```
frostming commented 3 days ago

Also, if they have to be converted to numpy first, they should be writable when converted back to tensor, no?

No, the underlying data is mapped to a block of memory to avoid a copy during (de)serialization.

We'll try to suppress the warning for users. For now, you can safely ignore it as long as you don't write to the tensor; if you do need to write to it, copy the tensor yourself.
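
For illustration, here is a rough sketch of why the received array ends up read-only, using np.frombuffer over immutable bytes as a stand-in for the actual zero-copy receive path (the real mechanism inside BentoML may differ):

```python
import numpy as np
import torch

# Immutable bytes stand in for a payload received off the wire.
payload = bytes(4 * 3)
arr = np.frombuffer(payload, dtype=np.float32)  # zero-copy view
print(arr.flags.writeable)                      # False: inherits immutability

t = torch.from_numpy(arr)  # shares that memory -> triggers the warning
safe = t.clone()           # clone() copies into fresh, writable storage
safe[0] = 1.0              # now safe to write
```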

rlleshi commented 3 days ago

Thanks for the quick response!

I see. Could bentoml perhaps preserve the state of the writeable flag under the hood? I have a bunch of services, and I'd have to remember to copy the tensor everywhere as it flows from one service to another.
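
In the meantime, I'm considering centralizing the copy with a clone-on-entry decorator instead of repeating it in every method. An untested sketch: `writable_tensors` is just my own helper, not a BentoML API; it assumes `@bentoml.api` resolves signatures through `functools.wraps`, and it won't silence the warning itself, since that fires in BentoML's validator before the method runs:

```python
import functools

import torch

def writable_tensors(fn):
    """Hypothetical helper: clone incoming tensors so the wrapped
    API method always sees tensors that own writable memory."""
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        args = tuple(a.clone() if isinstance(a, torch.Tensor) else a
                     for a in args)
        kwargs = {k: (v.clone() if isinstance(v, torch.Tensor) else v)
                  for k, v in kwargs.items()}
        return fn(self, *args, **kwargs)
    return wrapper

# Usage: apply under @bentoml.api so the clone happens after deserialization.
#
# @bentoml.api
# @writable_tensors
# def test(self, image: torch.Tensor) -> torch.Tensor:
#     image += 1  # safe: the clone owns its memory
#     return image
```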