When the serving is partitioned and the user passes inputs exceeding the batch size, we may need to concatenate computation results from multiple devices, which fails. Since post-processing generally transfers to Nx.BinaryBackend whenever it needs to slice the result, we now transfer the result upfront in the serving function, which also solves the concatenation case.
See https://github.com/elixir-nx/xla/issues/58#issuecomment-1808248714.
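
The change described above can be sketched roughly as follows. This is a minimal illustration, not the actual code from this change: the doubling computation, the variable names, and the `Mix.install` version constraint are assumptions made for the example. It runs on the default evaluator, so the transfer is effectively a no-op here; with a device-backed compiler, the same call would move each partition's result to the host.

```elixir
# Assumed version constraint, for illustration only.
Mix.install([{:nx, "~> 0.7"}])

serving =
  Nx.Serving.new(fn defn_options ->
    # Hypothetical computation standing in for the real model.
    predict = Nx.Defn.jit(fn batch -> Nx.multiply(batch, 2) end, defn_options)

    fn batch ->
      batch
      |> predict.()
      # Transfer upfront, so any later concatenation or slicing of
      # results happens on Nx.BinaryBackend rather than across devices.
      |> Nx.backend_transfer(Nx.BinaryBackend)
    end
  end)

batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
result = Nx.Serving.run(serving, batch)
IO.inspect(result)
```

Doing the transfer inside the serving function means every result is already on the host by the time the serving has to combine results from different partitions.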