pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Multi request batching #1235

Open toretak opened 3 years ago

toretak commented 3 years ago

I've been looking through previous issues, but I could not find a satisfying answer.

I have packed the model using model-archiver in Docker.

docker run --rm -it --name mar -v $(pwd)/output:/output -v \
$(pwd)/model:/model -v $(pwd)/src/:/src pytorch/torchserve:latest \
torch-model-archiver --model-name u2net --version ${MODEL_VERSION:-'1.0'} \
--model-file /src/u2net.py \
--serialized-file /model/u2net.pth --export-path /output \
--extra-files /src/unet_classes.py --handler /src/custom_handler.py

Then I run the model in Docker:

docker run --rm -it -v $(pwd)/output:/home/model-server/model-store \
-v $(pwd)/config.properties:/tmp/config.properties \
-p 8080:8080 -p 8081:8081 -p 8082:8082 pytorch/torchserve:latest \
torchserve --start --model-store model-store --ts-config /tmp/config.properties
Python executable: /usr/bin/python3
Config file: /tmp/config.properties
Inference address: http://0.0.0.0:8080
Management address: http://0.0.0.0:8081
Metrics address: http://0.0.0.0:8082
Model Store: /home/model-server/model-store
Initial Models: u2net.mar
Log dir: /home/model-server/logs
Metrics dir: /home/model-server/logs
Netty threads: 0
Netty client threads: 0
Default workers per model: 12
Blacklist Regex: N/A
Maximum Response Size: 6553500
Maximum Request Size: 6553500
Prefer direct buffer: false
Allowed Urls: [file://.*|http(s)?://.*]
Custom python dependency for model allowed: false
Metrics report format: prometheus
Enable metrics API: true
Workflow Store: /home/model-server/model-store
Model config: {"u2net": {"1.0": {"defaultVersion": true,"marName": "u2net.mar","minWorkers": 1,"maxWorkers": 1,"batchSize": 8,"maxBatchDelay": 250,"responseTimeout": 120}},"u2netp": {"1.0": {"defaultVersion": true,"marName": "u2netp.mar","minWorkers": 1,"maxWorkers": 1,"batchSize": 8,"maxBatchDelay": 250,"responseTimeout": 120}}}
2021-09-03 09:10:33,042 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager -  Loading snapshot serializer plugin...
2021-09-03 09:10:33,044 [INFO ] main org.pytorch.serve.ModelServer - Loading initial models: u2net.mar
2021-09-03 09:10:35,220 [DEBUG] main org.pytorch.serve.wlm.ModelVersionedRefs - Adding new version 1.0 for model u2net
2021-09-03 09:10:35,220 [DEBUG] main org.pytorch.serve.wlm.ModelVersionedRefs - Setting default version to 1.0 for model u2net
2021-09-03 09:10:35,221 [INFO ] main org.pytorch.serve.wlm.ModelManager - Model u2net loaded.
2021-09-03 09:10:35,221 [DEBUG] main org.pytorch.serve.wlm.ModelManager - updateModel: u2net, count: 1

Then I call the model multiple times:

curl -X POST http://127.0.0.1:8080/predictions/u2net -T "{bike.jpg,boat.jpg,horse.jpg}"

or from Python:

import aiohttp
import asyncio
import glob

images = glob.glob('test_data/test_images/*')

async def main():

    async with aiohttp.ClientSession() as session:
        for image in images:
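            # note: each response is awaited before the next request is sent,
            # so the requests arrive at the server one at a time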
            async with session.post('http://localhost:8080/predictions/u2net', data=open(image, 'rb')) as resp:
                print(resp.status)

loop = asyncio.get_event_loop()
loop.run_until_complete(main())

but in the TorchServe log I can see that the requests are processed sequentially.

2021-09-03 09:11:58,652 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-09-03 09:11:58,652 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-09-03 09:11:58,652 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (3000, 2000)
2021-09-03 09:11:58,697 [INFO ] W-9000-u2net_1.0-stdout MODEL_METRICS - HandlerTime.Milliseconds:903.98|#ModelName:u2net,Level:Model|#hostname:52033c12a7e8,requestID:076a76ee-d493-4b47-9253-f3e81335ae91,timestamp:1630660318
2021-09-03 09:11:58,700 [INFO ] W-9000-u2net_1.0-stdout MODEL_METRICS - PredictionTime.Milliseconds:904.05|#ModelName:u2net,Level:Model|#hostname:52033c12a7e8,requestID:076a76ee-d493-4b47-9253-f3e81335ae91,timestamp:1630660318
2021-09-03 09:11:58,703 [INFO ] W-9000-u2net_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 925
2021-09-03 09:11:58,707 [INFO ] W-9000-u2net_1.0 ACCESS_LOG - /172.17.0.1:53032 "POST /predictions/u2net HTTP/1.1" 200 1182
2021-09-03 09:11:58,707 [INFO ] W-9000-u2net_1.0 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:11:58,708 [DEBUG] W-9000-u2net_1.0 org.pytorch.serve.job.Job - Waiting time ns: 250420767, Backend time ns: 932032627
2021-09-03 09:11:58,708 [INFO ] W-9000-u2net_1.0 TS_METRICS - QueueTime.ms:250|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:11:58,708 [INFO ] W-9000-u2net_1.0 TS_METRICS - WorkerThreadTime.ms:7|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:11:59,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-09-03 09:11:59,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-09-03 09:11:59,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (1280, 720)
2021-09-03 09:11:59,550 [INFO ] W-9000-u2net_1.0-stdout MODEL_METRICS - HandlerTime.Milliseconds:574.61|#ModelName:u2net,Level:Model|#hostname:52033c12a7e8,requestID:9b5fdc78-df93-4c73-9eff-ae2bf4e3142f,timestamp:1630660319
2021-09-03 09:11:59,550 [INFO ] W-9000-u2net_1.0-stdout MODEL_METRICS - PredictionTime.Milliseconds:574.66|#ModelName:u2net,Level:Model|#hostname:52033c12a7e8,requestID:9b5fdc78-df93-4c73-9eff-ae2bf4e3142f,timestamp:1630660319
2021-09-03 09:11:59,551 [INFO ] W-9000-u2net_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 578
2021-09-03 09:11:59,551 [INFO ] W-9000-u2net_1.0 ACCESS_LOG - /172.17.0.1:53036 "POST /predictions/u2net HTTP/1.1" 200 832
2021-09-03 09:11:59,552 [INFO ] W-9000-u2net_1.0 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:11:59,552 [DEBUG] W-9000-u2net_1.0 org.pytorch.serve.job.Job - Waiting time ns: 250759189, Backend time ns: 581669301
2021-09-03 09:11:59,552 [INFO ] W-9000-u2net_1.0 TS_METRICS - QueueTime.ms:250|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:11:59,552 [INFO ] W-9000-u2net_1.0 TS_METRICS - WorkerThreadTime.ms:4|#Level:Host|#hostname:52033c12a7e8,timestamp:null
2021-09-03 09:12:00,421 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-09-03 09:12:00,421 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-09-03 09:12:00,421 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (770, 595)
2021-09-03 09:12:00,425 [INFO ] W-9000-u2net_1.0-stdout MODEL_METRICS - HandlerTime.Milliseconds:608.38|#ModelName:u2net,Level:Model|#hostname:52033c12a7e8,requestID:e36697cd-22a4-4074-9d68-b93927f7ef45,timestamp:1630660320
2021-09-03 09:12:00,425 [INFO ] W-9000-u2net_1.0 org.pytorch.serve.wlm.WorkerThread - Backend response time: 610

Context

We would like to batch multiple requests and run the inference just once for the whole batch.

Your Environment

There is a full repository to reproduce the issue: https://github.com/Biano-AI/TorchServe-u2net-handler

Custom handler


import base64
import io
import logging
import time

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class U2NetHandler(BaseHandler):

    def preprocess(self, data):
        """
        Resizes and normalizes the PIL images in the batch for a PyTorch model
        and returns them stacked into a single tensor.
        """
        normalize = Compose([
            Resize((320, 320)),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])
        ])
        return torch.stack([normalize(im) for im in data])

    def _get_mask_bytes(self, img, mask):
        logger.info(img.size)
        return Image.fromarray(mask).resize(img.size, Image.BILINEAR).tobytes()

    def postprocess(self, images, output):
        pred = output[0][:, 0, :, :]
        predict = self._normPRED(pred)
        predict_np = predict.cpu().detach().numpy()
        logger.info(f'predict_np shape {predict_np.shape}')
        res = []
        for i, im in enumerate(images):
            logger.info(f'postprocessing image {i}')
            mask = (predict_np[i] * 255).astype(np.uint8)
            res.append(self._get_mask_bytes(im, mask))
        return res

    # normalize the predicted SOD probability map
    # from the official U^2-Net repo
    def _normPRED(self, d):
        ma = torch.max(d)
        mi = torch.min(d)
        dn = (d - mi) / (ma - mi)
        return dn

    def load_images(self, data):
        images = []
        for row in data:
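            # each element of data corresponds to one request in the current batch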
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = Image.open(io.BytesIO(image))
            images.append(image)
        return images

    def handle(self, data, context):
        start_time = time.time()

        self.context = context
        metrics = self.context.metrics

        images = self.load_images(data)
        data_preprocess = self.preprocess(images)

        if not self._is_explain():
            output = self.inference(data_preprocess)
            output = self.postprocess(images, output)
        else:
            output = self.explain_handle(data_preprocess, data)

        stop_time = time.time()
        metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms')
        return output

Expected Behavior

I understand from the documentation that TorchServe should be able to aggregate multiple requests and call the model just once. If not, I apologize...

Thanks

HamidShojanazeri commented 3 years ago

@toretak you need to register the model with the target batch size. This can be done either through the management API, for example

torchserve --start --model-store model_store

curl -X POST "localhost:8081/models?url=https://torchserve.pytorch.org/mar_files/resnet-152-batch_v2.mar&batch_size=3&max_batch_delay=10&initial_workers=1"

or through config.properties if you are using the latest version, as indicated here.

toretak commented 3 years ago

Thanks for the reply @HamidShojanazeri, TorchServe is started with config.properties:

torchserve --start --model-store model-store --ts-config /tmp/config.properties

config.properties file:

inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
job_queue_size=100
load_models=u2net.mar
models={\
  "u2net": {\
    "1.0": {\
        "defaultVersion": true,\
        "marName": "u2net.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 8,\
        "maxBatchDelay": 1000,\
        "responseTimeout": 120\
    }\
  }\
}

I've tried different values of maxBatchDelay (500, 1000, 5000) and there is no effect.

msaroufim commented 3 years ago

Can you print the number of inferences being generated in the inference handler and check logs/model_log.log for it? There should be as many inferences as your batch size.
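
A quick way to check this is to log the length of the incoming data list at the top of the handler. A minimal sketch (the class name is illustrative and it assumes the stock BaseHandler):

import logging

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class BatchSizeLoggingHandler(BaseHandler):
    """Sketch: log how many requests TorchServe delivered in one backend call."""

    def preprocess(self, data):
        # len(data) equals the number of requests grouped into this batch
        # (at most the configured batchSize)
        logger.info("preprocess received %d request(s) in this batch", len(data))
        return super().preprocess(data)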

Also, the BERT example is a good one to see how batching works - you are given a single batch in the inference function, but you want to read back out an inference for each example and return it:

https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py

Also @toretak, I've noticed you're super active on torchserve so if you'd be interested in chatting over Zoom I'd love to set something up to go over your feedback and how you're using torchserve - my email is firstnamelastname@fb.com

toretak commented 3 years ago

Hi @msaroufim, thanks for your reply,

I have added debug prints to the inference and preprocess methods in the custom handler, and model_log.log says:

2021-10-04 15:41:58,086 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000
2021-10-04 15:41:58,086 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - [PID]32
2021-10-04 15:41:58,086 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Torch worker started.
2021-10-04 15:41:58,086 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Python runtime: 3.6.9
2021-10-04 15:41:58,097 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000.
2021-10-04 15:41:58,118 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - model_name: u2net, batchSize: 8
2021-10-04 15:42:26,638 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-04 15:42:26,771 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-04 15:42:26,804 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
2021-10-04 15:42:26,805 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
2021-10-04 15:42:26,810 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3487: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
2021-10-04 15:42:26,810 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
2021-10-04 15:42:26,810 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3613: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
2021-10-04 15:42:26,810 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   "See the documentation of nn.Upsample for details.".format(mode)
2021-10-04 15:42:27,267 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
2021-10-04 15:42:27,268 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
2021-10-04 15:42:27,273 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-04 15:42:27,273 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-04 15:42:27,273 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (3000, 2000)
2021-10-04 15:42:27,611 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-04 15:42:27,675 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-04 15:42:28,195 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-04 15:42:28,195 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-04 15:42:28,195 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (1280, 720)
2021-10-04 15:42:28,469 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-04 15:42:28,544 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-04 15:42:29,061 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-04 15:42:29,061 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-04 15:42:29,061 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (770, 595)

config.properties is still the same

inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
job_queue_size=100
load_models=u2net.mar
models={\
  "u2net": {\
    "1.0": {\
        "defaultVersion": true,\
        "marName": "u2net.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 8,\
        "maxBatchDelay": 1000,\
        "responseTimeout": 120\
    }\
  }\
}

According to the log, it looks like TorchServe really did three separate inferences instead of one batched inference...

I read https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py but I don't see anything special here. Am I missing something?

Thanks a lot...

msaroufim commented 3 years ago

OK, cool, so the most likely culprit seems to be the batch delay. I see you have images of size (3000, 2000) and (1280, 720), so just as a sanity check try 10-100x the batch delay and see if the whole batch is now processed. Also remove the response timeout for this test.

In Transformer_handler_generalized I just wanted you to see how the tensors are concatenated before being passed to the inference function: https://github.com/pytorch/serve/blob/master/examples/Huggingface_Transformers/Transformer_handler_generalized.py#L156
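
The key pattern there, sketched below in simplified form (the transform and model attributes are placeholders, not the actual Transformer handler code), is that preprocess builds one batched tensor out of all requests and postprocess returns exactly one result per request:

import torch


class BatchedHandlerSketch:
    """Simplified sketch of the batching contract; not the real handler."""

    def preprocess(self, data):
        # one element of data per request in the batch
        tensors = [self.transform(row.get("data") or row.get("body")) for row in data]
        return torch.stack(tensors)  # shape: (batch_size, ...)

    def inference(self, batch):
        with torch.no_grad():
            return self.model(batch)  # a single forward pass for the whole batch

    def postprocess(self, output):
        # must return a list with exactly one entry per request in the batch
        return output.tolist()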

toretak commented 3 years ago

So I've tested this config:

inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
job_queue_size=100
load_models=u2net.mar
models={\
  "u2net": {\
    "1.0": {\
        "defaultVersion": true,\
        "marName": "u2net.mar",\
        "minWorkers": 1,\
        "maxWorkers": 1,\
        "batchSize": 8,\
        "maxBatchDelay": 100000\
    }\
  }\
}

and it waited 100 seconds between each image.

model-server@6e6a671cfdf9:~$ cat logs/model_log.log 
2021-10-05 11:05:06,584 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Listening on port: /home/model-server/tmp/.ts.sock.9000
2021-10-05 11:05:06,585 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - [PID]32
2021-10-05 11:05:06,585 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Torch worker started.
2021-10-05 11:05:06,585 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Python runtime: 3.6.9
2021-10-05 11:05:06,595 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000.
2021-10-05 11:05:06,618 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - model_name: u2net, batchSize: 8
2021-10-05 11:07:07,358 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-05 11:07:07,480 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-05 11:07:07,519 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
2021-10-05 11:07:07,520 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
2021-10-05 11:07:07,525 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3487: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
2021-10-05 11:07:07,525 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
2021-10-05 11:07:07,525 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:3613: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
2021-10-05 11:07:07,525 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   "See the documentation of nn.Upsample for details.".format(mode)
2021-10-05 11:07:07,963 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG - /usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1805: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
2021-10-05 11:07:07,963 [WARN ] W-9000-u2net_1.0-stderr MODEL_LOG -   warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
2021-10-05 11:07:07,964 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-05 11:07:07,964 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-05 11:07:07,964 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (3000, 2000)
2021-10-05 11:08:48,038 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-05 11:08:48,062 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-05 11:08:48,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-05 11:08:48,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-05 11:08:48,542 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (1280, 720)
2021-10-05 11:10:28,566 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-05 11:10:28,580 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-05 11:10:29,178 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (1, 320, 320)
2021-10-05 11:10:29,178 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-05 11:10:29,179 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (770, 595)

There must be an issue in the handler, I guess... I am trying to load all incoming images into a list: https://github.com/Biano-AI/TorchServe-u2net-handler/blob/master/src/custom_handler.py#L135

and then normalize all loaded images and convert them to tensors in preprocessing: https://github.com/Biano-AI/TorchServe-u2net-handler/blob/master/src/custom_handler.py#L67-L78

Thanks a lot for your time, @msaroufim!

msaroufim commented 3 years ago

How are you sending requests to TorchServe? Are you using the requests library by any chance? In that case the calls will be synchronous.

Try sending 2 curl statements with & between them to send 2 parallel requests and see if the issue goes away.

It feels like your handler is fine, because predict_np shape shows it is only getting one image at a time.
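
The Python equivalent is to dispatch the requests concurrently instead of awaiting each response before sending the next one. A minimal sketch with aiohttp and asyncio.gather (assuming the same /predictions/u2net endpoint and local test images):

import asyncio
import glob

import aiohttp


async def predict(session, path):
    # post one image; all coroutines run concurrently
    with open(path, 'rb') as f:
        body = f.read()
    async with session.post('http://localhost:8080/predictions/u2net', data=body) as resp:
        return resp.status


async def main():
    images = glob.glob('test_data/test_images/*')
    async with aiohttp.ClientSession() as session:
        # gather fires all requests at once, so several can land in the same batch window
        statuses = await asyncio.gather(*(predict(session, p) for p in images))
    print(statuses)


asyncio.run(main())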

toretak commented 3 years ago

@msaroufim you are absolutely right... I was sending requests using this curl command the whole time:

curl -X POST http://127.0.0.1:8080/predictions/u2net -T "{bike.jpg,boat.jpg,horse.jpg}"

and that was the issue.

When the requests are sent like this:

curl -X POST http://127.0.0.1:8080/predictions/u2net -T "{bike.jpg}" & curl -X POST http://127.0.0.1:8080/predictions/u2net -T "{boat.jpg}"

batching works!

...
2021-10-08 07:29:12,653 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - Connection accepted: /home/model-server/tmp/.ts.sock.9000.
2021-10-08 07:29:12,674 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - model_name: u2net, batchSize: 8
2021-10-08 07:29:20,237 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === preprocess called ===
2021-10-08 07:29:20,377 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - === inference in handler called ===
2021-10-08 07:29:21,382 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - predict_np shape (2, 320, 320)
2021-10-08 07:29:21,382 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-08 07:29:21,382 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (1280, 720)
2021-10-08 07:29:21,388 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - postprocessing image 0
2021-10-08 07:29:21,389 [INFO ] W-9000-u2net_1.0-stdout MODEL_LOG - (3000, 2000)

So thank you very much! It was quite a surprise for me...

Vert53 commented 2 years ago

This should really be added to the batch inference documentation, as the example there only shows how to run a single image. I was pretty confused until I stumbled on this issue.

https://pytorch.org/serve/batch_inference_with_ts.html

"""
Run inference to test the model.

$ curl http://localhost:8080/predictions/resnet-152-batch_v2 -T kitten.jpg
{
  "tiger_cat": 0.5848360657691956,
  "tabby": 0.3782736361026764,
  "Egyptian_cat": 0.03441936895251274,
  "lynx": 0.0005633446853607893,
  "quilt": 0.0002698268508538604
}
"""

msaroufim commented 2 years ago

Agreed lemme reopen this to keep track

Vert53 commented 2 years ago

@toretak were you able to get some Python asyncio code working asynchronously with the TorchServe API?

toretak commented 2 years ago

@Vert53 Hi, we don't actually need Python asyncio in the TorchServe handler directly, so I didn't test it. In fact I can't imagine a use case for it... but it should work, I suppose. Do you have some (probably not working) implementation?

Vert53 commented 2 years ago

Hi @toretak, what I meant is to use asyncio for the requesting, not the serving (handler). I managed to write this async code to test how fast TorchServe worked on my setup using the ImageNet dataset. Sharing it in case it is of any use.

import json
import time
import aiohttp
import asyncio
import aiofiles
from torchvision.datasets import ImageFolder
from typing import Tuple

valdir = '/pytorch/imagenet/ILSVRC2012/val'
index_to_name_path = 'index_to_name.json'  # mapping  {'791': ['n04204347', 'shopping_cart'],.....}

class ImageNetLoader:
    def __init__(self,
                 folder: ImageFolder):
        self.folder = folder
        self.iter_samples = iter(folder.samples)

    def __aiter__(self):
        return self

    async def __anext__(self) -> Tuple[bytes, int]:
        try:
            sample_path, target = next(self.iter_samples)
        except StopIteration:
            raise StopAsyncIteration

        async with aiofiles.open(sample_path, 'rb') as sample_file:
            sample = await sample_file.read()

        return sample, target

async def infer_request(session: aiohttp.ClientSession,
                        url: str,
                        sample: bytes,
                        target: int,
                        queue: asyncio.Queue
                        ) -> None:
    async with session.post(url, data=sample) as response:
        if response.status == 200:
            output = await response.text()
            await queue.put((output, target))

async def inference_session(url,
                            loader,
                            queue
                            ) -> None:
    async with aiohttp.ClientSession() as session:
        infers = []
        async for sample, target in loader:
            infers.append(asyncio.create_task(infer_request(session=session,
                                                            url=url,
                                                            sample=sample,
                                                            target=target,
                                                            queue=queue)))
        await asyncio.gather(*infers, return_exceptions=True)

class ImagenetPostProcessor:
    def __init__(self):
        self.correct_predictions = 0
        self.total_predictions = 0

    async def postprocess_results(self,
                                  queue: asyncio.Queue,
                                  index_to_name: dict
                                  ) -> None:
        while True:
            output, target = await queue.get()
            top_1_prediction = next(iter(json.loads(output).keys()))
            target_str = index_to_name[str(target)][1]
            if top_1_prediction == target_str:
                self.correct_predictions += 1
            self.total_predictions += 1
            queue.task_done()

async def main(url, loader):
    postp = ImagenetPostProcessor()
    with open(index_to_name_path) as f:
        index_to_name = json.load(f) 
    queue = asyncio.Queue()
    producer = asyncio.create_task(inference_session(url, loader, queue))
    consumer = asyncio.create_task(postp.postprocess_results(queue, index_to_name))
    await producer
    await queue.join()
    consumer.cancel()
    print(f'total pred {postp.total_predictions}')
    return postp.correct_predictions/postp.total_predictions

if __name__ == '__main__':
    imagenet_folder = ImageFolder(valdir)
    imagenet_loader = ImageNetLoader(imagenet_folder)
    a = time.time()
    correct = asyncio.run(
        main('http://localhost:8080/predictions/resnet50', imagenet_loader)
    )
    b = time.time()
    print(b - a)
    print(correct)

agunapal commented 2 years ago

Here's an example of sending async batched requests using Python: https://github.com/pytorch/serve/tree/master/examples/image_classifier/near_real_time_video