Open: hooman-bayer opened this issue 1 year ago
If you're using FastAPI, the author (Tiangolo) has a nice project called Asyncer, which has a very nice `asyncify` function. It's just a wrapper on top of anyio, which does the heavy lifting, but it makes it trivial to call the sync SageMaker I/O within asyncio flows. Here's an example with Hugging Face:
```python
from typing import Any, cast

import boto3
import sagemaker
from asyncer import asyncify
from fastapi import FastAPI
from sagemaker.huggingface import HuggingFacePredictor

app = FastAPI()


@app.get("/completion")
async def get_completion():
    return await get_huggingface_completion("Please give me the stuff!")


async def get_huggingface_completion(total_prompt: str) -> dict[str, Any]:
    # asyncify runs the blocking SageMaker call in a worker thread.
    response = await asyncify(_communicate_with_sagemaker)(total_prompt)
    return _get_completion(response, total_prompt)  # _get_completion: app-specific parsing, defined elsewhere


def _communicate_with_sagemaker(total_prompt: str) -> Any:
    session = boto3.Session(region_name="us-west-2")
    sage_session = sagemaker.Session(boto_session=session)
    predictor = HuggingFacePredictor("your-endpoint", sagemaker_session=sage_session)
    payload = {"inputs": total_prompt, "parameters": {"max_new_tokens": 1024}}
    return cast(dict[str, Any], predictor.predict(payload))
```
@phillipuniverse thanks for the suggestion. Sure, that helps, but it still runs the call on a different thread, which is still different from a pure async module that would be more optimized for Python.
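For context, under the hood `asyncify` just hands the blocking call to anyio's worker-thread pool, roughly like this (a minimal sketch; the helper names and the limiter tweak are illustrative, and 40 threads is anyio's default cap):

```python
import anyio.to_thread


async def run_blocking_predict(predict_fn, payload):
    # The sync SageMaker call runs in a worker thread, so the event loop stays
    # responsive, but every in-flight request still ties up a thread.
    return await anyio.to_thread.run_sync(predict_fn, payload)


async def raise_worker_thread_limit(total_threads: int = 100) -> None:
    # Optional: anyio caps its default worker pool at 40 threads, which can
    # become the bottleneck under higher concurrency.
    anyio.to_thread.current_default_thread_limiter().total_tokens = total_threads
```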
Dropping a comment in case anyone else happens to stumble across this thread. The suggestion to use `asyncify` unfortunately didn't work for me. It appeared to be functional under low concurrency loads, but started failing a large percentage of requests once the number of concurrent requests rose above ~5.

Going to try directly issuing requests over HTTP with an async-aware library and see if that does the trick.
Turns out that may not have been correct. After a lot of frustration it appears that the errors that were getting generated were actually coming from the Sagemaker endpoint (which in hindsight seems kind of obvious). The immediacy of the errors made me think it was something to do with how they were being initiated but now I don't think that was true.
In case this helps anyone else, this is how I implemented it by leveraging `aiohttp`:
```python
import hashlib
import json

import aiohttp
import boto3
from aws_request_signer import AwsRequestSigner


async def invoke_sagemaker_endpoint(region_name, endpoint_name, payload):
    # Wrapped in a function so the snippet runs inside an event loop; the name is arbitrary.
    sagemaker_endpoint_url = (
        f"https://runtime.sagemaker.{region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations"
    )

    async with aiohttp.ClientSession() as session:
        _refreshable_credentials = boto3.Session(region_name=region_name).get_credentials()

        # Get signed headers
        creds = _refreshable_credentials.get_frozen_credentials()
        signer = AwsRequestSigner(
            region=region_name,
            access_key_id=creds.access_key,
            secret_access_key=creds.secret_key,
            session_token=creds.token,
            service="sagemaker",
        )
        payload_bytes = json.dumps(payload).encode("utf-8")
        payload_hash = hashlib.sha256(payload_bytes).hexdigest()
        headers = {"Content-Type": "application/json"}
        headers.update(
            signer.sign_with_headers("POST", sagemaker_endpoint_url, headers, payload_hash)
        )

        try:
            async with session.post(sagemaker_endpoint_url, headers=headers, json=payload) as response:
                response.raise_for_status()
                return await response.json()
        except aiohttp.ClientError as e:
            raise RuntimeError(f"Request to SageMaker endpoint failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred: {e}") from e
```
This is only slightly pseudo-code, but it gives the general idea of how to go about signing the headers, and it would be easy to wrap in an async endpoint, for example:
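Here's a minimal sketch of that; `invoke_sagemaker_endpoint` is just the name I gave the wrapper above, and the FastAPI wiring is illustrative:

```python
from fastapi import FastAPI

app = FastAPI()


@app.get("/completion")
async def get_completion(prompt: str):
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 1024}}
    # The signed aiohttp call from above, awaited without blocking the event loop.
    return await invoke_sagemaker_endpoint("us-west-2", "your-endpoint", payload)
```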
Thanks a lot @ewellinger. Have you tried benchmarking your approach against the following, which uses `asyncify`?
```python
app = FastAPI()


@app.get("/completion")
async def get_completion():
    return await get_huggingface_completion("Please give me the stuff!")


async def get_huggingface_completion(total_prompt: str) -> dict[str, Any]:
    response = await asyncify(_communicate_with_sagemaker)(total_prompt)
    return _get_completion(response, total_prompt)


def _communicate_with_sagemaker(total_prompt: str) -> Any:
    session = boto3.Session(
        region_name="us-west-2",
    )
    sage_session = sagemaker.Session(boto_session=session)
    predictor = HuggingFacePredictor("your-endpoint", sagemaker_session=sage_session)
    payload = {"inputs": total_prompt, "parameters": {"max_new_tokens": 1024}}
    return cast(dict[str, Any], predictor.predict(payload))
```
I assume yours would perform way better. It is beyond me why SageMaker would not offer such a basic feature. On the one hand, I assume they want it to be used with LLMs, but on the other hand there is no support for the basic needs that come with that (e.g. async and streaming).
I haven't, I only realized what was happening pretty late last night.
We're hitting these endpoints for LLM predictions, so the latency is already pretty high, and I'd imagine the difference between directly using `aiohttp` and `asyncify` would be negligible in this case.
Just a really quick-and-dirty comparison between the two implementations. This was hitting an LLM with 20 total requests, running 5 requests concurrently at a time.

This was with the `asyncer` approach:
```
Summary:
  Total:        60.2530 secs
  Slowest:      18.1773 secs
  Fastest:      9.3643 secs
  Average:      13.5625 secs
  Requests/sec: 0.3319

  Total data:   45417 bytes
  Size/request: 2390 bytes

Response time histogram:
  9.364  [1] |■■■■■■■■■■
  10.246 [0] |
  11.127 [2] |■■■■■■■■■■■■■■■■■■■■
  12.008 [3] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  12.889 [1] |■■■■■■■■■■
  13.771 [3] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  14.652 [4] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  15.533 [0] |
  16.415 [3] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  17.296 [0] |
  18.177 [2] |■■■■■■■■■■■■■■■■■■■■

Latency distribution:
  10% in 10.3807 secs
  25% in 11.7108 secs
  50% in 13.7912 secs
  75% in 15.7309 secs
  90% in 18.1773 secs
  0% in 0.0000 secs
  0% in 0.0000 secs

Details (average, fastest, slowest):
  DNS+dialup: 0.0024 secs, 9.3643 secs, 18.1773 secs
  DNS-lookup: 0.0017 secs, 0.0000 secs, 0.0083 secs
  req write:  0.0004 secs, 0.0000 secs, 0.0034 secs
  resp wait:  13.5570 secs, 9.3516 secs, 18.1769 secs
  resp read:  0.0024 secs, 0.0002 secs, 0.0170 secs

Status code distribution:
  [200] 19 responses

Error distribution:
  [1] Post "http://localhost:8000/api/llm/text_generation/v1": context deadline exceeded (Client.Timeout exceeded while awaiting headers)
```
Here was the breakdown with aiohttp:
```
Summary:
  Total:        50.7949 secs
  Slowest:      16.6846 secs
  Fastest:      5.4675 secs
  Average:      11.4514 secs
  Requests/sec: 0.3937

  Total data:   47699 bytes
  Size/request: 2384 bytes

Response time histogram:
  5.468  [1] |■■■■■■
  6.589  [0] |
  7.711  [1] |■■■■■■
  8.833  [1] |■■■■■■
  9.954  [2] |■■■■■■■■■■■
  11.076 [1] |■■■■■■
  12.198 [7] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  13.319 [3] |■■■■■■■■■■■■■■■■■
  14.441 [1] |■■■■■■
  15.563 [2] |■■■■■■■■■■■
  16.685 [1] |■■■■■■

Latency distribution:
  10% in 8.8009 secs
  25% in 10.2199 secs
  50% in 11.4831 secs
  75% in 13.1079 secs
  90% in 15.2409 secs
  95% in 16.6846 secs
  0% in 0.0000 secs

Details (average, fastest, slowest):
  DNS+dialup: 0.0023 secs, 5.4675 secs, 16.6846 secs
  DNS-lookup: 0.0020 secs, 0.0000 secs, 0.0093 secs
  req write:  0.0002 secs, 0.0000 secs, 0.0017 secs
  resp wait:  11.4449 secs, 5.4571 secs, 16.6785 secs
  resp read:  0.0039 secs, 0.0001 secs, 0.0286 secs

Status code distribution:
  [200] 20 responses
```
So it does look like there is a benefit to using `aiohttp`, but it would probably need more extensive testing to say how large the difference is.

Also, I can confirm that using `asyncer` does, in fact, work. The issues I hit previously were actually on the endpoint side.
Awesome @ewellinger, thanks a lot! It looks like your approach is slightly better. SageMaker has recently introduced `invoke_endpoint_with_response_stream`, but it is still a synchronous operation (in the SageMaker Python SDK). With your approach, though, one can get close to a decent streaming and async setup (something like below):
```python
import hashlib
import json

import aiohttp
import boto3
from aws_request_signer import AwsRequestSigner


async def invoke_sagemaker_stream(region_name, endpoint_name, payload):
    sagemaker_endpoint_url = (
        f"https://runtime.sagemaker.{region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations-response-stream"
    )

    async with aiohttp.ClientSession() as session:
        _refreshable_credentials = boto3.Session(region_name=region_name).get_credentials()

        # Get signed headers
        creds = _refreshable_credentials.get_frozen_credentials()
        signer = AwsRequestSigner(
            region=region_name,
            access_key_id=creds.access_key,
            secret_access_key=creds.secret_key,
            session_token=creds.token,
            service="sagemaker",
        )
        payload_bytes = json.dumps(payload).encode("utf-8")
        payload_hash = hashlib.sha256(payload_bytes).hexdigest()
        headers = {"Content-Type": "application/json"}
        headers.update(
            signer.sign_with_headers("POST", sagemaker_endpoint_url, headers, payload_hash)
        )

        try:
            async with session.post(sagemaker_endpoint_url, headers=headers, json=payload) as response:
                response.raise_for_status()
                # Now, instead of returning a JSON response, we handle the stream.
                async for line in response.content:
                    print(line.decode("utf-8"))
                    # Process each line here as needed
        except aiohttp.ClientError as e:
            raise RuntimeError(f"Request to SageMaker endpoint failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred: {e}") from e


# Example usage:
# asyncio.run(invoke_sagemaker_stream("your-region", "your-endpoint", {"inputs": "Test", "parameters": {}}))
```
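If the goal is to stream tokens back to a client, one option (just a sketch, assuming `invoke_sagemaker_stream` above is changed to `yield line.decode("utf-8")` instead of printing; the route name is illustrative) is to feed it straight into FastAPI's `StreamingResponse`:

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/stream")
async def stream_completion(prompt: str) -> StreamingResponse:
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": 1024}}
    # Assumes invoke_sagemaker_stream yields decoded chunks instead of printing them.
    return StreamingResponse(
        invoke_sagemaker_stream("your-region", "your-endpoint", payload),
        media_type="text/plain",
    )
```

Note that the raw body from `invocations-response-stream` is framed as an AWS event stream, so real code would also need to unwrap that framing before forwarding text to the client.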
+1 on this, would be great to have
I'm not terribly hopeful that this will be implemented in a timely manner.
My company has priority support, and I raised a support ticket to have them respond to this feature request; it was closed not long after, without them chiming in here. I didn't have time for the back-and-forth since the implementation above technically works :/
**Describe the feature you'd like**
Like many other inference libraries in Python (e.g. OpenAI), create a real awaitable version of `predict` for realtime SageMaker inference endpoints. This would help Python applications that use FastAPI and asyncio deliver realtime responses without blocking the main event loop. Please note that this feature is different from the one currently available here, where the predictions are written to an S3 bucket. This feature would work exactly like https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict but with an `await`, in real `asyncio` style. SageMaker is an amazing library, and it would be just way better for production environments using FastAPI to have this feature.
**How would this feature be used? Please describe.**
Currently, the sync version looks like this:
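A minimal sketch, mirroring the Hugging Face predictor usage from the examples in this thread:

```python
from sagemaker.huggingface import HuggingFacePredictor

predictor = HuggingFacePredictor("your-endpoint")
payload = {"inputs": "Please give me the stuff!", "parameters": {"max_new_tokens": 1024}}

# Blocks the calling thread (and, inside FastAPI, the event loop) until the endpoint responds.
response = predictor.predict(payload)
```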
The async version might look like this:
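This is hypothetical; an awaitable `predict` does not exist in the SDK today, it is just the shape of the requested API:

```python
# Hypothetical: the same call, but awaitable so it does not block the event loop.
response = await predictor.predict(payload)
```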
**Describe alternatives you've considered**
I considered subclassing the `Predictor` and adding the async version there.

**Additional context**
For modern Python applications built on top of FastAPI and asyncio, it is crucial to use async modalities to avoid blocking the server's main event loop (in the case of scalable applications). Therefore, having a real awaitable functionality would avoid blocking the main event loop of applications that leverage SageMaker. Thanks a lot.
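For reference, a rough and untested sketch of the subclassing alternative mentioned above (the class and method names are illustrative; it just offloads the blocking call to a thread so callers can await it):

```python
import asyncio
from typing import Any

from sagemaker.huggingface import HuggingFacePredictor


class AwaitableHuggingFacePredictor(HuggingFacePredictor):
    """Illustrative subclass adding an awaitable wrapper around predict()."""

    async def apredict(self, data: Any, **kwargs: Any) -> Any:
        # asyncio.to_thread keeps the event loop free while predict() blocks
        # in a worker thread; it is not true non-blocking I/O.
        return await asyncio.to_thread(self.predict, data, **kwargs)
```

This is still just the thread-offloading trick under the hood, which is exactly why a real awaitable `predict` in the SDK would be nicer.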