pytorch / serve

Serve, optimize and scale PyTorch models in production
https://pytorch.org/serve/
Apache License 2.0

Fix sequence continuous batching close session race condition #3198

Closed namannandan closed 1 week ago

namannandan commented 1 week ago

Description

In the current Sequence Batching event dispatcher implementation (https://github.com/pytorch/serve/blob/079ff7b5d31d79627fa73d247c4f86a12893f8c5/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java#L187-L218), every maxBatchDelay interval we do one of the following:

  1. queue a task to poll for new job groups that the worker can process jobs from, or
  2. queue a task to poll an existing job group queue for new jobs
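The dispatch loop above can be sketched as follows. This is a minimal, hypothetical Python sketch of the behavior described, not TorchServe's actual Java implementation; `dispatch_tick`, `poll_executor`, and the task bookkeeping are illustrative names. It shows how enqueueing a poll task on every tick, without tracking whether one is already in flight, produces duplicate tasks:

```python
# Hypothetical, simplified sketch of the dispatcher loop described above.
# Names (poll_executor, dispatch_tick, pending_tasks) are illustrative only.
from concurrent.futures import ThreadPoolExecutor

poll_executor = ThreadPoolExecutor(max_workers=2)
pending_tasks = []  # records what was enqueued, to make the duplication visible


def poll_new_job_groups():
    pass  # placeholder: accept newly opened sequences for this worker


def poll_job_group(group):
    pass  # placeholder: fetch the next job from this sequence's queue


def dispatch_tick(active_groups):
    """One maxBatchDelay interval: enqueue a poll task of one kind or the other,
    without checking whether an equivalent task is already in flight."""
    if not active_groups:
        pending_tasks.append(("new_groups", None))
        poll_executor.submit(poll_new_job_groups)
    else:
        for g in active_groups:
            pending_tasks.append(("group", g))
            poll_executor.submit(poll_job_group, g)


# Two ticks elapse before the first task completes: the same work is
# enqueued twice for the same job group.
dispatch_tick(["seq-1"])
dispatch_tick(["seq-1"])
print(pending_tasks)  # duplicate ("group", "seq-1") entries
```

Tracking in-flight tasks per job group (rather than enqueueing unconditionally each tick) is one way to avoid the duplication.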

As a result, we can end up with the following outcomes:

  1. Duplicate tasks to poll for new job groups, which may tie up more than one thread in the poll executor thread pool
  2. Duplicate tasks to poll the same existing job group queue, which interferes with job group cleanup on close-session requests

Concretely, the issue is triggered in the following scenario:

  1. maxNumSequence number of sessions are actively open
  2. A sequence gets a stream response request
  3. The same sequence subsequently gets a close session request
  4. Although the sequence is closed and should free up capacity to open a new session, it continues to hold session capacity until the session times out, and only then is it cleaned up.

In summary, a stream response request prevents graceful session closure and cleanup.

The likely root cause is that the session cleanup logic fails to detect session closure after a stream response: the task that polls jobs from an existing job group has already gone past the point where closed sessions are detected by the time the close request arrives: https://github.com/pytorch/serve/blob/079ff7b5d31d79627fa73d247c4f86a12893f8c5/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java#L220-L226
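The ordering race can be sketched as follows. Again this is a hypothetical Python sketch (with illustrative names like `JobGroup` and `cleaned_up`), not the actual Java code: the poll task checks the closed flag once, before waiting on the queue, so a close request that lands after that check goes unnoticed by the in-flight task:

```python
# Hypothetical sketch of the ordering race described above (illustrative
# names, not TorchServe's actual Java implementation).
import queue


class JobGroup:
    def __init__(self):
        self.jobs = queue.Queue()
        self.closed = False
        self.cleaned_up = False


def poll_job_group(group):
    # The closed-session check happens here, before we wait for jobs...
    if group.closed:
        group.cleaned_up = True
        return None
    # ...so a close request that arrives while this task is past the check
    # is not observed by it.
    return group.jobs.get(timeout=0.1)  # stands in for polling the job queue


group = JobGroup()
group.jobs.put("stream-response-request")
poll_job_group(group)  # passes the closed check, consumes the job
group.closed = True    # close-session request arrives afterwards

# No task re-checks the flag, so the group holds capacity until timeout.
print(group.cleaned_up)  # False
```

Re-checking the closed flag after each poll (or on task completion) would let the cleanup path observe a close that arrived mid-flight.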

Moreover, with a stream response, we enqueue a poll job group task on the poll executor queue for every chunk we send back, although we expect only one such task to be active at any given point in time.
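A minimal sketch of that per-chunk enqueueing, with illustrative names (`stream_response`, `task_queue`) rather than the actual Java API:

```python
# Hypothetical sketch: one poll task is enqueued per streamed chunk,
# although only a single active poll task per job group is expected.
def send(chunk):
    pass  # placeholder for writing the chunk back to the client


def stream_response(chunks, task_queue):
    for chunk in chunks:
        send(chunk)
        task_queue.append("pollJobGroup")  # one task enqueued per chunk


task_queue = []
stream_response(["c1", "c2", "c3"], task_queue)
print(len(task_queue))  # 3 tasks enqueued for one stream, where 1 was expected
```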


Feature/Issue validation/testing

Without the fix in this PR, the test fails as follows:

$ python -m pytest test/pytest/test_example_stateful_sequence_continuous_batching_http.py::test_infer_stateful_cancel
.....
.....
  AssertionError: assert '{\n  "code":...eueSize"\n}\n' == '-1'
    - -1
    + {
    +   "code": 503,
    +   "type": "ServiceUnavailableException",
    +   "message": "Model \"stateful\" has no worker to serve inference request. Please use scale workers API to add workers. If this is a sequence inference, please check if it is closed, or expired; or exceeds maxSequenceJobQueueSize"
    + }

With the fix in this PR:

$ python -m pytest test/pytest/test_example_stateful_sequence_continuous_batching_http.py::test_infer_stateful_cancel

=============================================================== test session starts ===============================================================
platform darwin -- Python 3.9.6, pytest-7.3.1, pluggy-1.5.0
rootdir: /Volumes/workplace/pytorch/serve
plugins: cov-4.1.0, timeout-2.3.1, mock-3.14.0
collected 1 item                                                                                                                                  

test/pytest/test_example_stateful_sequence_continuous_batching_http.py .                                                                    [100%]

================================================================ warnings summary =================================================================
venvs/ts_dev/lib/python3.9/site-packages/urllib3/__init__.py:35
  /Volumes/workplace/pytorch/serve/venvs/ts_dev/lib/python3.9/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================== 1 passed, 1 warning in 15.58s ==========================================================