elixir-nx / xla

Pre-compiled XLA extension
Apache License 2.0

How to get the number of client devices #67

Closed aramallo closed 6 months ago

aramallo commented 6 months ago

Hi,

I am starting to learn Nx and friends. I was searching for a function to determine how many client devices I have available and found EXLA.Client.get_supported_platforms() which returns a map of clients and (supposedly) number of devices.

iex(3)> EXLA.Client.get_supported_platforms()
%{host: 10, interpreter: 1}

However, if I try to create an Nx.Serving with a device_id > 0, I get an error saying there is no such device.

...
serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
  compile: [batch_size: 32, sequence_length: 8],
  defn_options: [compiler: EXLA, device_id: 2]
)
...
iex(13)> Nx.Serving.run(serving, text)
** (RuntimeError) No matching device found for device_id 2
    (exla 0.6.4) lib/exla/device_buffer.ex:55: EXLA.DeviceBuffer.unwrap!/1
    (exla 0.6.4) lib/exla/device_buffer.ex:31: EXLA.DeviceBuffer.copy_to_device/3
    (exla 0.6.4) lib/exla/defn/buffers.ex:25: EXLA.Defn.Buffers.filter_by_indexes_map/4
    (exla 0.6.4) lib/exla/defn.ex:342: EXLA.Defn.maybe_outfeed/7
    (stdlib 5.1.1) timer.erl:270: :timer.tc/2
    (exla 0.6.4) lib/exla/defn.ex:283: anonymous fn/7 in EXLA.Defn.__compile__/4
    (nx 0.6.4) lib/nx/defn.ex:313: anonymous fn/4 in Nx.Defn.compile/3
    iex:13: (file)

Then looking at the implementation I found a call to EXLA.NIF.get_device_count/1 which returns 1 as result.

iex(4)> {_, ref} = EXLA.NIF.get_host_client()
{"message":"TfrtCpuClient created.","time":"2023-12-14T21:07:57.648Z","severity":"INFO"}
{:ok, #Reference<0.3908760169.3496607746.123882>}
iex(5)> EXLA.NIF.get_device_count(ref)
1

Is the difference between these two results expected? Am I misunderstanding the return value of EXLA.Client.get_supported_platforms()?

jonatanklosko commented 6 months ago

The number EXLA.Client.get_supported_platforms() gives for :host is the number of CPU cores, but in practice there's only one device. You can force more devices with XLA_FLAGS=--xla_force_host_platform_device_count=10, but that's meant for testing, because XLA should already use more cores whenever appropriate. Multiple devices are relevant when you literally have multiple GPU devices attached to the machine :)
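
A minimal sketch of checking the host device count, with or without the flag. It assumes EXLA 0.6.x, where `EXLA.Client.fetch!/1` returns the client for a platform and the client struct has a `device_count` field. The script name `devices.exs` is hypothetical, and the flag most likely has to be in the environment before EXLA creates the host client.

```elixir
# devices.exs (hypothetical file name). Run it standalone, for example:
#
#     XLA_FLAGS=--xla_force_host_platform_device_count=10 elixir devices.exs
#
# Without XLA_FLAGS, the host client is expected to report a single device,
# as described in the comment above.
Mix.install([{:exla, "~> 0.6.4"}])

# Fetch the (lazily created) client for the :host platform.
client = EXLA.Client.fetch!(:host)

# Assumption: the number of devices is exposed as a struct field.
device_count = client.device_count

IO.puts("host devices: #{device_count}")
```

Valid `device_id` values for that client would then range from 0 to `device_count - 1`.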

aramallo commented 6 months ago

Thanks @jonatanklosko, that's very clear.

My naive idea was to run N servings of a SentenceTransformer, where N was the number of cores. Do you mind if I ask here what would be the best way to do that? Say I want to compute embeddings for a list of inputs using those cores.

Thanks!

josevalim commented 6 months ago

The general answer from XLA is that it is best to not do that and instead pass a batch of sentences that will be processed in parallel using all cores by default.
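
A rough sketch of the batched approach, reusing the serving options from the original question. The library versions, the model repository `sentence-transformers/all-MiniLM-L6-v2`, and the sample sentences are assumptions chosen for illustration; they are not taken from this thread.

```elixir
# Assumed versions, matching the exla/nx 0.6.4 seen in the stack trace above.
Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:exla, "~> 0.6.4"}
])

# Hypothetical model choice; any sentence embedding model supported by
# Bumblebee should work the same way.
repo = {:hf, "sentence-transformers/all-MiniLM-L6-v2"}
{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)

# A single serving on the default device (no device_id option).
serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 32, sequence_length: 8],
    defn_options: [compiler: EXLA]
  )

# Pass the whole list in one call instead of running one serving per core.
# The list is kept within the compiled batch_size of 32.
texts = ["first sentence", "second sentence", "third sentence"]
results = Nx.Serving.run(serving, texts)

# Each result is a map with an :embedding tensor.
embeddings = Enum.map(results, & &1.embedding)
IO.inspect(length(embeddings), label: "embeddings computed")
```

With this shape, the parallelism is left to XLA inside a single computation rather than being managed through several servings.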

That's the theory, but XLA for CPU is not as fast as it could be, nor does it use all cores all the time (most optimizations are done when running on the GPU). You could set XLA_FLAGS=--xla_force_host_platform_device_count=10 and then you would have multiple devices, which you could use with multiple servings, but in my experience that doesn't make a difference.

josevalim commented 6 months ago

Btw, thank you for all your work on Partisan and that collection of libraries. ❤️ If there is anything we can help with, don't hesitate to ask (we are also on the #machine-learning channel of the Erlang Ecosystem Foundation Slack).

aramallo commented 6 months ago

@josevalim Thank you so much for the answer and the kind words!! I will move future discussions to the EEF Slack. I recently started working with Elixir and these libraries and I am completely blown away, so thanks to you and the team for the extraordinary work.