elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

[Proposal] `Nx.DynamicServing` #1501

Closed Benjamin-Philip closed 5 months ago

Benjamin-Philip commented 5 months ago

Introduction

Nx's Nx.Serving provides a way to batch computations, as well as distribute these batches across multiple nodes and partitions. However, the mechanisms it provides to react to variation in traffic are limited. Two instances where improvements can be made come to mind:

Firstly, Serving instances that are idling can hibernate. However, the GPU on that node is still idling and still billed. GPUs are expensive, and releasing the node can lead to significant cost improvements.

Secondly, there is no way to react to a large spike in traffic. If a spike is significantly larger than the traffic the serving has been tuned for, the system will fall behind the incoming requests, which in turn leads to significant lag. In such a scenario, the ideal course of action is to increase the number of available servings.

In this issue I propose Nx.DynamicServing, a way to elastically scale servings (i.e. have a dynamic number of servings), enabling users to tune their servings for the common scenario without worrying over edge cases.

Basic Architecture

Nx.DynamicServing will be a singleton GenServer (maybe a distributed application?) that manages the servings. In essence, all it does is add or remove servings (i.e. processes in Nx.Serving.PG) according to load. Thus users start a link to Nx.DynamicServing instead of starting a link to Nx.Serving. Since we just update the process group, users can do a batched run as usual. The actual nodes will be added or removed with FLAME by placing each serving as a child on a runner in a FLAME.Pool.

For starting the FLAME Pool we have two options: either we let the user start the pool as a child of their application and accept the pool name, or we start the FLAME Pool as a child of DynamicServing, accepting all the parameters used for configuring FLAME Pools in addition to those of Nx.Serving. The name of said pool can be Nx.DynamicServing.ServingRunner. I personally prefer the latter approach.

Why build on top of Nx.Serving instead of replacing it?

By running the servings on top of FLAME, you lose flexibility in how you run them. If a user sees no need to autoscale their nodes, wants to customize their clustering, or runs on a platform not yet supported by FLAME, Nx.DynamicServing is a downgrade from Nx.Serving. We do not want to force FLAME on our users, and therefore I feel we should retain Nx.Serving and build Nx.DynamicServing on top of it.

Scaling down

Choosing when to shut down a node is easy: we shut it down after a timeout when idling (i.e. there are no enqueued batches or requests). The timeout to shut down an Nx.Serving will be given by a new parameter to Serving, :serving_timeout, which overrides :hibernate_after. The node will then be shut down immediately (i.e. FLAME's :idle_shutdown_after is 0).

Scaling up

Choosing when to scale up is comparatively difficult. One way which I propose is to limit the maximum number of enqueued batches for each partition/node to some number, :max_batches. If all nodes have reached this number (i.e. no node can accept requests) on an inference request, we can launch another node.

This can be internally represented by only including available batch "slots" in the serving process group. Thus we launch another node when the process group is empty. Once a serving frees a batch slot, that slot can re-appear on the process group.

An advantage of representing the limit in this manner is that we route more requests to servings with more free slots, thereby uniformly using our resources, while still maintaining the simple and random nature of the current load balancing implementation.
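As a rough illustration of the slot idea above, here is a minimal sketch in Python (used only for brevity; this is not Nx or FLAME code). The class `SlotGroup` and its methods are hypothetical names, and `max_batches` stands in for the proposed `:max_batches` parameter. The group lists one entry per free slot, so a uniformly random pick favours servings with more free slots, and an empty group is the signal to launch another node.

```python
import random


class SlotGroup:
    """Toy stand-in for a process group that lists one entry per free batch slot."""

    def __init__(self, max_batches):
        self.max_batches = max_batches  # hypothetical per-serving slot limit
        self.free = {}                  # serving name -> number of free slots

    def add_serving(self, name):
        # A newly launched serving joins with all of its slots free.
        self.free[name] = self.max_batches

    def members(self):
        # One entry per free slot: a serving with 3 free slots appears 3 times.
        return [name for name, n in self.free.items() for _ in range(n)]

    def acquire(self, rng=random):
        """Pick a free slot uniformly at random.

        Returns None when the group is empty, which is the scale-up signal.
        """
        slots = self.members()
        if not slots:
            return None
        name = rng.choice(slots)
        self.free[name] -= 1
        return name

    def release(self, name):
        # A finished batch puts its slot back into the group.
        self.free[name] = min(self.free[name] + 1, self.max_batches)
```

With two servings and `max_batches=2`, four calls to `acquire()` hand out every slot, and a fifth returns `None` until some serving calls `release`.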

This is just one way of deciding when and how to scale up. There may be other ways, and I'm open to discussing them and amending the proposal accordingly.

Using DynamicServing and plain Distributed Serving simultaneously

Some users may wish to maintain a fixed set of nodes without shutting them down to combat cold starts, and thus start plain Servings in addition to the DynamicServing. I do not propose that such functionality be supported, as this would mean that these Distributed Servings would be in PG permanently, breaking our upscaling mechanism.

Instead, we can accept a :min_servings parameter giving the minimum number of servings, and create these in the pool with the :serving_timeout and :hibernate_after parameters nullified on init.

(/cc @seanmor5).

josevalim commented 5 months ago

Thanks @Benjamin-Philip!

We can do improvements in this direction but I believe most of it should actually happen outside of Nx. You want a load balancer that is better than the current random one and I believe this can be implemented in a way that is agnostic to Nx.Serving, such that it benefits any shared resources running on the BEAM, not only Nx.

For example, if you want to call Nx.Serving on a node, you call the load balancer, which tells you which node to find Nx.Serving on, and then you message it directly. The load balancer will keep track of ongoing requests, and knows when to bring up new instances and shut down unused ones. If you want to track batches, you can give a weight whenever you request the load balancer (this request has a weight of 4). I don't think Nx even needs to know it exists (in the same way the serving implementation doesn't really know about the current load balancer: the name registration we have today is orthogonal to the serving logic itself, it is just a very simple mechanism for load balancing that comes out of the box).
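A minimal sketch of such a weight-aware balancer, written in Python purely for illustration (it is not an existing Nx, FLAME, or BEAM API). `LoadBalancer`, `checkout`, `checkin`, and `idle_nodes` are hypothetical names, and picking the least-loaded node is an assumption; the comment above only says the balancer tracks ongoing requests and their weights.

```python
class LoadBalancer:
    """Tracks the total weight of in-flight requests per node."""

    def __init__(self):
        self.in_flight = {}  # node -> summed weight of ongoing requests

    def add_node(self, node):
        self.in_flight.setdefault(node, 0)

    def checkout(self, weight=1):
        # Assumed policy: route to the node with the least in-flight weight.
        # Ties go to the node that was added first.
        node = min(self.in_flight, key=self.in_flight.get)
        self.in_flight[node] += weight
        return node

    def checkin(self, node, weight=1):
        # Called when the request finishes, with the same weight as checkout.
        self.in_flight[node] -= weight

    def idle_nodes(self):
        # Nodes with nothing in flight are candidates for shutdown.
        return [n for n, w in self.in_flight.items() if w == 0]
```

The caller asks `checkout(weight=4)` for a request that carries four batches, messages the returned node directly, and calls `checkin` afterwards; the worker itself never needs to know the balancer exists.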

Benjamin-Philip commented 5 months ago

> such that it benefits any shared resources running on the BEAM, not only Nx.

Do you have any examples of any resources that would benefit from such load balancing today?

> I believe this can be implemented in a way that is agnostic to Nx.Serving,

Sure, we can treat each serving as a worker and balance. But then we need to consider if the worker batches or not. If we need to batch on a node, do we partition? How do we partition? My question is how agnostic do you suggest? Are you suggesting that we build just a load balancer, or a load balancer that can additionally batch process?

josevalim commented 5 months ago

Any application that takes too long to boot and/or requires batching would benefit.

And yes, I propose we build a generic load balancer. It doesn’t need the concept of batches; it only needs to understand that some loads have a bigger weight (because they contain multiple batches).

Benjamin-Philip commented 5 months ago

I'll close this issue since it is out of the scope of Nx itself.

Would you like me to discuss the specifics of the load balancer privately with you or publicly in this issue (in case @polvalente or @seanmor5 have an opinion)?

josevalim commented 5 months ago

@Benjamin-Philip we can discuss it here (in this issue) or on the #machine-learning Slack of EEF or whatever else you prefer!

Benjamin-Philip commented 5 months ago

To summarize, we want a load balancer that balances load across worker pids. I had a resource-based load balancer in the serving in mind, but for now scoring all the resources into a weight as @josevalim suggested seems good enough to start with. Since this weight is dynamic, we can also include a static weight to denote differences that do not change (like a worker on an RTX 4090 vs. an A100).

We can let the load balancer shut down workers after a certain timeout, so we need not depend on the worker to enforce our rules. In case workers need to be shut down by some other metric, a timeout of :infinity can be specified, and that worker can shut itself down. However, this does mean we will need to monitor idling and log requests.

However, we still need to decide when we want to scale up. One idea is to scale up when all workers have a resource weight of 0 or less. This could be cost effective, but throughput will decrease as each node reaches maximum capacity. Another idea is to scale up as soon as any one node reaches a resource weight of 0 or less. This maintains a constant throughput (as the number of request-accepting nodes is constant), but may not be cost effective. I also have a feeling that different scenarios would have different parameters to optimize for. Maybe we should include multiple load balancing (and scaling) algorithms?
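The two triggers can be written down as small predicates. This is a sketch in Python for illustration only; the function names are hypothetical, and `remaining` is assumed to map each worker to its remaining resource weight, where 0 or less means the worker is at capacity.

```python
def scale_up_when_all_full(remaining):
    """Cost-oriented rule: scale up only once every worker is at capacity.

    Note that with no workers at all this returns True, so the first
    worker is started on demand.
    """
    return all(weight <= 0 for weight in remaining.values())


def scale_up_when_any_full(remaining):
    """Throughput-oriented rule: scale up as soon as one worker is at capacity."""
    return any(weight <= 0 for weight in remaining.values())
```

For `remaining = {"a": 0, "b": 3}` the first rule waits and the second scales up, which is the cost versus throughput trade-off described above. Passing the rule in as a function is one way to support several scaling algorithms side by side.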

Another (simpler) question I have is what should this project be called. Also, should it be under the Nx namespace/organization, or should I maintain it under my own user, and under its own namespace?

-- bp


josevalim commented 5 months ago

Keep in mind that this can become quite interesting. For example, if you want to scale up, you need to make sure that you have enough requests incoming so you are not spawning a new machine for something that was a temporary spike. In particular, if it takes 3s to boot a new machine (probably more because of ML models), you need to make sure there are enough requests and that they won't all be processed before the machine finishes booting.

Those problems are discussed under Queueing Theory. For example, you could measure how frequently requests arrive, how many FLAMEs there are, and how long it takes for requests to be processed (and probably more, it has been a while!). Based on that, you can start inferring how long it takes to go through the whole queue, and therefore whether spawning new resources would be worthwhile. For example, this is how call centers can estimate that "you are 10th in line, it will take 12 minutes".
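A back-of-the-envelope version of that estimate, sketched in Python for illustration. It is a crude steady-rate approximation rather than a proper queueing model, and every name here (`drain_time`, `worth_scaling_up`, and their parameters) is hypothetical. It assumes requests arrive at a constant rate, each worker handles one request at a time, and every request takes the same mean service time.

```python
import math


def drain_time(queue_len, arrival_rate, workers, service_time):
    """Estimated seconds until the current queue is empty.

    queue_len:    requests currently waiting
    arrival_rate: new requests per second
    workers:      number of workers processing requests
    service_time: mean seconds one worker spends on one request
    """
    capacity = workers / service_time      # requests served per second
    net_rate = capacity - arrival_rate     # how fast the queue shrinks
    if net_rate <= 0:
        return math.inf                    # the queue never drains
    return queue_len / net_rate


def worth_scaling_up(queue_len, arrival_rate, workers, service_time, boot_time):
    # Only spawn a machine if the backlog would outlast its boot time.
    estimate = drain_time(queue_len, arrival_rate, workers, service_time)
    return estimate > boot_time
```

The call-center example corresponds to `drain_time(10, 0, 1, 72)`, which is about 720 seconds, i.e. 12 minutes for 10 callers ahead of you, one agent, and 72 seconds per call. With the 3s boot time mentioned above, `worth_scaling_up` says no whenever the existing workers would clear the backlog in under 3 seconds.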

More interestingly, queueing theory began with a Dane named Agner Krarup Erlang. I am not asking you to build a queueing theory system but, if you like this stuff, it can be quite fun to learn, and it may point you in a direction where you model the system not by how many "slots" are available or whether a batch is full, but rather as queues. :)

I think this is more of a FLAME-related project, so your personal account is probably the best place for it for now!

PS: it has been like 15 years since I last studied queueing theory, so I probably got some details wrong, but the general idea is here!