elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Add quantized int types #1528

Closed: josevalim closed this 2 months ago

josevalim commented 2 months ago

The EXLA test suite is currently failing:

  2) test quantized types s2 (EXLA.BackendTest)
     test/exla/backend_test.exs:202
     ** (RuntimeError) dims and input_strides_in_bytes must have equal sizes, got 1 and 0
     code: tensor = Nx.s2([-2, -1, 1])
     stacktrace:
       (exla 0.8.0) lib/exla/device_buffer.ex:85: EXLA.DeviceBuffer.unwrap!/1
       (exla 0.8.0) lib/exla/device_buffer.ex:47: EXLA.DeviceBuffer.place_on_device/4
       (exla 0.8.0) lib/exla/backend.ex:46: EXLA.Backend.from_binary/3
       test/exla/backend_test.exs:203: (test)

Erlang does not allow us to pass bitstrings to C, only binaries (the number of bits must be divisible by 8). So I am currently padding the data to a whole number of bytes. That should be fine, because all XLA cares about is a pointer to the beginning of the data. However, the call still fails. Here is the source that we call into:

https://github.com/openxla/xla/blob/091ef0c6ec9cb48b04d3f764b6c8b549b189d06c/xla/pjrt/pjrt_stream_executor_client.cc#L823
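To make the padding step concrete, here is a minimal Elixir sketch that packs 2-bit signed values into a bitstring and then appends zero bits up to the next byte boundary so the result is a binary. The module and function names (`QuantizedPacking`, `pack/2`, `pad_to_binary/1`) are hypothetical and not part of Nx or EXLA. The bit order (most significant bits first, as Erlang's bit syntax lays them out) is also an assumption and may not match the layout Nx or XLA actually use for `s2` tensors.

```elixir
defmodule QuantizedPacking do
  # Hypothetical helper, not part of Nx/EXLA: packs each integer into
  # `bits` bits (two's complement), most significant bit first.
  def pack(values, bits) do
    Enum.reduce(values, <<>>, fn v, acc ->
      <<acc::bitstring, v::signed-size(bits)>>
    end)
  end

  # Hypothetical helper: appends zero bits until the bit size is a
  # multiple of 8, turning the bitstring into a binary.
  def pad_to_binary(bitstring) when is_bitstring(bitstring) do
    pad = rem(8 - rem(bit_size(bitstring), 8), 8)
    <<bitstring::bitstring, 0::size(pad)>>
  end
end

packed = QuantizedPacking.pack([-2, -1, 1], 2)
IO.inspect(bit_size(packed))  # prints 6
IO.inspect(is_binary(packed)) # prints false: 6 bits is not a whole byte

padded = QuantizedPacking.pad_to_binary(packed)
IO.inspect(padded)            # prints <<180>>, i.e. 0b10110100
```

Padding leaves the start of the buffer unchanged and only adds trailing zero bits, which matches the reasoning that XLA only needs a pointer to the beginning of the data.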

jonatanklosko commented 2 months ago

For the record, the upstream XLA issue: https://github.com/openxla/xla/issues/16795

josevalim commented 2 months ago

:green_heart: :blue_heart: :purple_heart: :yellow_heart: :heart: