elixir-nx / nx_iree

Elixir and Nx bindings for the IREE runtime and compiler
Apache License 2.0

Increase Nx API support when using IREE #13

Open kevinschweikert opened 1 month ago

kevinschweikert commented 1 month ago

As seen in the doctest file, IREE doesn't support, for one reason or another, a significant chunk of the Nx API. These should all be fixable, but there are a few different error classes we need to deal with.


Original issue: I am trying out this project with the following livebook code on a MacBook M3 Pro:

Mix.install(
  [
    {:bumblebee, github: "elixir-nx/bumblebee"},
    {:nx_iree, github: "elixir-nx/nx_iree"},
    {:kino, "~> 0.14.1"},
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}
  ]
)

NxIREE.list_drivers() |> IO.inspect(label: "drivers")

{:ok, [metal | _]} = NxIREE.list_devices("metal")

flags = [
  "--iree-hal-target-backends=metal-spirv",
  "--iree-input-type=stablehlo_xla",
  "--iree-execution-model=async-internal"
]

Nx.Defn.global_default_options(
  compiler: NxIREE.Compiler,
  iree_compiler_flags: flags,
  iree_runtime_options: [device: metal]
)

Nx.global_default_backend(NxIREE.Backend)

model = "openai/whisper-tiny"

{:ok, whisper} = Bumblebee.load_model({:hf, model})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, model})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, model})

serving_iree =
  Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
    chunk_num_seconds: 30,
    timestamps: :segments
  )

Kino.start_child!({Nx.Serving, serving: serving_iree, name: WhisperIREE})

file = "path/to/video.mp4"

transcript = Nx.Serving.batched_run(WhisperIREE, {:file, file})
dbg(transcript)

I got the serving to start, but it fails at the Nx.Serving.batched_run step with:

11:25:16.009 [error] Task #PID<0.1570.0> started from WhisperIREE terminating
** (MatchError) no match of right hand side value: {"", 1}
    (nx_iree 0.0.1-pre.7) lib/nx_iree.ex:29: NxIREE.compile/3
    (nx_iree 0.0.1-pre.7) lib/nx_iree/compiler.ex:53: NxIREE.Compiler.__compile__/4
    (nx_iree 0.0.1-pre.7) lib/nx_iree/compiler.ex:76: NxIREE.Compiler.__jit__/5
    (nx 0.9.0) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
    (bumblebee 0.5.3) lib/bumblebee/audio/speech_to_text_whisper.ex:85: anonymous fn/4 in Bumblebee.Audio.SpeechToTextWhisper.speech_to_text_whisper/5
    (nx 0.9.0) lib/nx/serving.ex:1833: anonymous fn/2 in Nx.Serving.Default.handle_batch/3
    (nx 0.9.0) lib/nx/serving.ex:1609: anonymous fn/5 in Nx.Serving.server_maybe_task/1
    (telemetry 1.3.0) /Users/kevinschweikert/Library/Caches/mix/installs/elixir-1.17.2-erts-15.0/6fefd2c3b92204308e47b1f11ca721d5/deps/telemetry/src/telemetry.erl:324: :telemetry.span/3
Function: #Function<32.20405337/0 in Nx.Serving.server_maybe_task/1>
    Args: []

I will investigate further to capture the stderr output from iree_compile in this case.
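
The `{"", 1}` in the MatchError has the same `{output, exit_status}` shape that `System.cmd/3` returns, with empty stdout and a non-zero exit status. Assuming the compiler is invoked as an external command (an inference from the trace, not something confirmed here), one way to see the missing diagnostics is to merge stderr into the captured output. The sketch below uses a hypothetical `CompileDebug` module and a stand-in `sh` command rather than the real compiler binary; it is not part of nx_iree.

```elixir
defmodule CompileDebug do
  # Hypothetical helper for illustration only; not part of nx_iree.
  # Runs an external command and returns its output even when it fails,
  # instead of pattern matching on a zero exit status.
  def run(cmd, args) do
    # stderr_to_stdout: true merges stderr into the returned output,
    # so diagnostics are not lost when stdout is empty.
    case System.cmd(cmd, args, stderr_to_stdout: true) do
      {output, 0} -> {:ok, output}
      {output, status} -> {:error, {status, output}}
    end
  end
end

# Stand-in for a failing compiler call: writes to stderr and exits with 1.
CompileDebug.run("sh", ["-c", "echo 'unsupported op' >&2; exit 1"])
|> IO.inspect(label: "compile result")
```

With the stand-in command, the error tuple carries both the exit status and the text that was written to stderr, which is the information the bare `{"", 1}` hides.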

polvalente commented 1 month ago

Unfortunately, this is because there's still a substantial chunk of the Nx API that isn't compiling properly with IREE: https://github.com/elixir-nx/nx_iree/blob/8901b07671db830cdcee5fea472330437e1a902a/test/nx_iree/nx_test.exs#L4

For Whisper specifically, we need to figure out how to support fft and friends, given that stablehlo.fft is not supported by IREE.

polvalente commented 1 month ago

For FFT, we can look into https://github.com/DTolm/VkFFT