elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

"cannot perform operation across devices mps and cpu" when running examples with torchx and mps #275

Closed lambadalambda closed 10 months ago

lambadalambda commented 10 months ago

I tested a few examples from the documentation with the Torchx backend set to the :mps device. They all fail with this error:

** (ArgumentError) cannot perform operation across devices mps and cpu
    (torchx 0.6.3) lib/torchx.ex:485: anonymous fn/2 in Torchx.prepare_tensors!/1
    (elixir 1.14.3) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (elixir 1.14.3) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (torchx 0.6.3) lib/torchx.ex:370: Torchx.clip/3
    (elixir 1.14.3) lib/enum.ex:1662: anonymous fn/3 in Enum.map/2
    (elixir 1.14.3) lib/enum.ex:4299: Enum.reduce_range/5
    (elixir 1.14.3) lib/enum.ex:2472: Enum.map/2
    (torchx 0.6.3) lib/torchx/backend.ex:529: Torchx.Backend.indices_from_nx/2

Here's example code to trigger it.

Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:torchx, ">= 0.0.0"}
])

Nx.default_backend({Torchx.Backend, device: :mps})

{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

serving = Bumblebee.Text.fill_mask(bert, tokenizer)
text = "The capital of [MASK] is Paris."

Nx.Serving.run(serving, text)
|> IO.inspect()

I understand that :mps support is still experimental, but I've been running accelerated Stable Diffusion with Python and libtorch for months now, so this should definitely be possible.
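Until Torchx handles mixed-device tensors on :mps, a possible workaround is to keep the Torchx default device on :cpu, which avoids the cross-device clip call entirely. This is a sketch of the same example with only the device changed; whether :cpu gives acceptable throughput for your model is untested here:

```elixir
Mix.install([
  {:bumblebee, "~> 0.4.2"},
  {:torchx, ">= 0.0.0"}
])

# Workaround sketch: stay on the CPU device so every tensor the
# backend creates internally lives on the same device. :cpu is
# the Torchx default, so this line is explicit for clarity.
Nx.default_backend({Torchx.Backend, device: :cpu})

{:ok, bert} = Bumblebee.load_model({:hf, "bert-base-uncased"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

serving = Bumblebee.Text.fill_mask(bert, tokenizer)
text = "The capital of [MASK] is Paris."

Nx.Serving.run(serving, text)
|> IO.inspect()
```

This only sidesteps the bug; the underlying fix is making Torchx place (or transfer) its internally created index tensors on the caller's device.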

josevalim commented 10 months ago

This is tracked here: https://github.com/elixir-nx/nx/issues/679

PRs are welcome. :)