pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Long wait times for first request from TorchScript model #24

Closed: fbbradheintz closed this issue 3 years ago

fbbradheintz commented 4 years ago

I have two identical models, one in code + weights, the other in TorchScript. Doing inference with TorchScript takes far, far longer, which is surprising.

The setup:

The non-TorchScript model is just the DenseNet-161 model archive from the README.md quick start.

The TorchScript model is the same one, but exported to TorchScript thus:

import torch
import torchvision

# Load the pretrained eager-mode DenseNet-161 and export it to TorchScript via scripting
d161 = torchvision.models.densenet161(pretrained=True)
tsd161 = torch.jit.script(d161)
tsd161.save('tsd161.pt')

It was then packaged with:

torch-model-archiver --model-name tsd161 --version 1.0 --serialized-file tsd161.pt --handler image_classifier

The server is started with:

torchserve --start --model-store model_store --models densenet161=densenet161.mar tsd161=tsd161.mar

This is the timing output from calling the regular model:

time curl -X POST http://127.0.0.1:8080/predictions/densenet161 -T kitten.jpg
[
  {
    "tiger_cat": 0.46933549642562866
  },
  {
    "tabby": 0.4633878469467163
  },
  {
    "Egyptian_cat": 0.06456148624420166
  },
  {
    "lynx": 0.0012828214094042778
  },
  {
    "plastic_bag": 0.00023323034110944718
  }
]
curl -X POST http://127.0.0.1:8080/predictions/densenet161 -T kitten.jpg  0.01s user 0.01s system 2% cpu 0.428 total

And from the TorchScript:

time curl -X POST http://127.0.0.1:8080/predictions/tsd161 -T kitten.jpg
[
  {
    "282": "0.46933549642562866"
  },
  {
    "281": "0.4633878469467163"
  },
  {
    "285": "0.06456148624420166"
  },
  {
    "287": "0.0012828214094042778"
  },
  {
    "728": "0.00023323034110944718"
  }
]
curl -X POST http://127.0.0.1:8080/predictions/tsd161 -T kitten.jpg  0.01s user 0.01s system 0% cpu 1:16.54 total

The identical probabilities in the two outputs (the TorchScript one just lacks human-readable labels) show we're dealing with the same model in both instances.

I'm marking this launch blocking, at least until we understand what's happening.

fbbradheintz commented 4 years ago

Note that this isn't a "lag time loading the model" issue - repeated attempts give similar results.

I'll try it with some other models as well, to see how consistent the issue is.

harshbafna commented 4 years ago

@fbbradheintz This seems like a PyTorch-specific issue.

Please find attached sample prediction code for the densenet161 model in eager and TorchScript modes.

test_torchscript.txt test_eager.txt
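
Roughly, the scripts time a single prediction in each mode. A minimal sketch of that kind of comparison (not the attached code; it assumes torchvision is installed and a kitten.jpg is in the working directory):

import time
import torch
import torchvision
from PIL import Image
from torchvision import transforms

# Standard ImageNet preprocessing for DenseNet-161
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
image = preprocess(Image.open('kitten.jpg')).unsqueeze(0)

eager = torchvision.models.densenet161(pretrained=True).eval()
scripted = torch.jit.script(eager)

for name, model in [('eager', eager), ('torchscript', scripted)]:
    start = time.time()
    with torch.no_grad():
        out = model(image)
    print(f'{name}: {time.time() - start:.3f}s, top class index {out.argmax(1).item()}')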

- TorchScript mode execution time

real    1m56.721s
user    1m56.270s
sys     0m0.950s

- Eager mode execution time

(base) USL07109 harsh_bafna$ time python test_eager.py
['n02123045', 'tabby'] --- 1.1276381015777588 seconds ---

real    0m1.974s
user    0m1.975s
sys     0m0.409s

We also found the following open issue in PyTorch related to a performance issue in TorchScript mode: https://github.com/pytorch/pytorch/issues/30365

fbbradheintz commented 4 years ago

Ah - I hadn't seen that issue. Will investigate from my side. I'll keep this issue open in the meantime.

fbbradheintz commented 4 years ago

@harshbafna Can you share the scripts you used for that test?

fbbradheintz commented 4 years ago

Thanks to @nairbv for the tandem diagnosis on this.

The issue only shows up on the model's first forward pass. There's a bunch of compilation work that needs to happen before TorchScript can execute an inference; after that happens, things get much faster. I'll verify that once a worker has been hit once, perf improves.

At this time, the only way to kick off this precompilation is to perform a forward pass. We discussed different ways to accommodate this; one of them is to have TorchServe run a warm-up pass on a user-provided sample input, which could be exposed as a single, optional flag on torchserve, something like:

torchserve --start --blahblah --sample_input=valid_input.pt

For the time being, this doesn't need to block launch, but we should make a plan to improve this in future revs.
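
Until something like that exists, loading code (for example a custom handler) can pay the compilation cost up front with a throwaway forward pass. A minimal sketch, assuming an ImageNet-style input shape:

import torch

# Load the scripted model and run a couple of throwaway forward passes so
# that the JIT compilation/optimization cost is paid before real traffic.
model = torch.jit.load('tsd161.pt')
model.eval()

warmup_input = torch.randn(1, 3, 224, 224)  # hypothetical sample input
with torch.no_grad():
    for _ in range(2):
        model(warmup_input)
# Subsequent inferences should now run at full speed.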

nairbv commented 4 years ago

Filed a JIT ticket for potential improvements: https://github.com/pytorch/pytorch/issues/33354

jeremiahschung commented 4 years ago

Confirmed this is fixed in the latest 1.7 RC, thanks to the fix in pytorch/pytorch#33354.

Followed the steps in the original issue description to create a TorchScripted DenseNet model and served it with TorchServe.

time curl -X POST http://127.0.0.1:8080/predictions/tsd161 -T kitten.jpg
{
  "282": 0.4693361222743988,
  "281": 0.4633875787258148,
  "285": 0.06456127017736435,
  "287": 0.0012828144244849682,
  "728": 0.00023322943889070302
}
real    0m0.496s
user    0m0.004s
sys 0m0.004s
time curl -X POST http://127.0.0.1:8080/predictions/tsd161 -T kitten.jpg
{
  "282": 0.4693361222743988,
  "281": 0.4633875787258148,
  "285": 0.06456127017736435,
  "287": 0.0012828144244849682,
  "728": 0.00023322943889070302
}
real    0m0.049s
user    0m0.008s
sys 0m0.000s

@chauhang, can we close this issue now, or wait until 1.7 is out?

harshbafna commented 3 years ago

Validated this on the latest master with PT 1.7 on a p3.8xlarge instance, with 4 model workers each loaded on a different GPU device; the response time is about 1.4 seconds.

ubuntu@ip-172-31-73-130:~$ time curl -X POST http://localhost:8080/predictions/densenet161_scripted -T serve/examples/image_classifier/kitten.jpg 
{
  "tiger_cat": 0.46933576464653015,
  "tabby": 0.463387668132782,
  "Egyptian_cat": 0.06456146389245987,
  "lynx": 0.0012828221078962088,
  "plastic_bag": 0.00023323048662859946
}
real    0m1.344s
user    0m0.000s
sys 0m0.006s
ubuntu@ip-172-31-73-130:~$ 
ubuntu@ip-172-31-73-130:~$ time curl -X POST http://localhost:8080/predictions/densenet161_scripted -T serve/examples/image_classifier/kitten.jpg 
{
  "tiger_cat": 0.46933576464653015,
  "tabby": 0.463387668132782,
  "Egyptian_cat": 0.06456146389245987,
  "lynx": 0.0012828221078962088,
  "plastic_bag": 0.00023323048662859946
}
real    0m1.347s
user    0m0.000s
sys 0m0.006s
ubuntu@ip-172-31-73-130:~$ 
ubuntu@ip-172-31-73-130:~$ time curl -X POST http://localhost:8080/predictions/densenet161_scripted -T serve/examples/image_classifier/kitten.jpg 
{
  "tiger_cat": 0.46933576464653015,
  "tabby": 0.463387668132782,
  "Egyptian_cat": 0.06456146389245987,
  "lynx": 0.0012828221078962088,
  "plastic_bag": 0.00023323048662859946
}
real    0m1.394s
user    0m0.000s
sys 0m0.006s
ubuntu@ip-172-31-73-130:~$ 
ubuntu@ip-172-31-73-130:~$ time curl -X POST http://localhost:8080/predictions/densenet161_scripted -T serve/examples/image_classifier/kitten.jpg 
{
  "tiger_cat": 0.46933576464653015,
  "tabby": 0.463387668132782,
  "Egyptian_cat": 0.06456146389245987,
  "lynx": 0.0012828221078962088,
  "plastic_bag": 0.00023323048662859946
}
real    0m1.374s
user    0m0.000s
sys 0m0.006s
ubuntu@ip-172-31-73-130:~$ 

Closing the ticket.

ozancaglayan commented 2 years ago

Hi,

Sorry to bring this up again, but I thought this might be the right place.

I'm also having a similar issue with torchserve-nightly, but interestingly with an eager model. After launching the server, the forward pass for the very first HTTP request takes around 320 ms, whereas subsequent ones take around 9 ms. I've timed the different snippets and, indeed, it's the forward call that takes 99% of this time.

Do you have any ideas?

nairbv commented 2 years ago

@ozancaglayan That sounds like a distinct issue, so you might want to file a separate one. Is this difference torchserve-specific, or is it something you can reproduce without torchserve?

ozancaglayan commented 2 years ago

Sorry, ignore this. I didn't notice that the model was being deployed onto the GPU without any further setup, so that overhead is probably due to the model running on the GPU and some CUDA cold-start cost. There still seems to be a slight lag on the first calls on CPU, though it's probably negligible.
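
For reference, a minimal sketch for separating that first-call GPU warm-up from steady-state latency (assuming CUDA is available, and using densenet161 only as a stand-in for my model):

import time
import torch
import torchvision

device = torch.device('cuda')
model = torchvision.models.densenet161(pretrained=True).eval().to(device)
x = torch.randn(1, 3, 224, 224, device=device)

with torch.no_grad():
    for i in range(3):
        torch.cuda.synchronize()
        start = time.time()
        model(x)
        torch.cuda.synchronize()
        print(f'call {i}: {(time.time() - start) * 1000:.1f} ms')
# The first call typically includes CUDA context and kernel warm-up;
# later calls reflect steady-state latency.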

Thanks!