elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Problem with `Nx.LinAlg.norm`'s `axes` option #1486

Open: LostKobrakai opened this issue 1 month ago

LostKobrakai commented 1 month ago

Copied from a discussion on Slack:

# t is an f64[3][2] tensor
t =
  Nx.tensor(
    [
      [-1.0, -1.0],
      [0.0, 0.0],
      [1.0, 1.0]
    ],
    type: {:f, 64}
  )

Nx.LinAlg.norm(t, axes: [1])

This fails with:

** (ArgumentError) cannot broadcast tensor of dimensions {3, 2} to {3}
    (nx 0.7.2) lib/nx/shape.ex:345: Nx.Shape.binary_broadcast/4
    (nx 0.7.2) lib/nx.ex:5407: Nx.devectorized_element_wise_bin_op/4
    (nx 0.7.2) lib/nx/lin_alg.ex:402: Nx.LinAlg.norm_integer/3
    (nx 0.7.2) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
    (nx 0.7.2) lib/nx/defn/evaluator.ex:87: Nx.Defn.Evaluator.precompile/3
    (nx 0.7.2) lib/nx/defn/evaluator.ex:65: Nx.Defn.Evaluator.__compile__/4
    (nx 0.7.2) lib/nx/defn/evaluator.ex:58: Nx.Defn.Evaluator.__jit__/5
    /Volumes/benni/Livebook/voronoi.livemd#cell:ptoi3arwh7gaepjoxh74lpa3sxz7zlw7:5: (file)

I think the expected result should be:

#Nx.Tensor<
  f64[3]
  [
    1.4142135623730951,
    0.0,
    1.4142135623730951
  ]
>
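Until this is resolved, one possible workaround is to compute the Euclidean (L2) norm of each row by hand: square the entries, sum along axis 1, and take the square root. This is a sketch that assumes Nx ~> 0.7 and that the default Euclidean norm is what `Nx.LinAlg.norm(t, axes: [1])` should return. For the first row it gives sqrt((-1)^2 + (-1)^2) = sqrt(2), consistent with the expected result above.

```elixir
# Assumption: run as a standalone script or Livebook setup cell with Nx ~> 0.7.
Mix.install([{:nx, "~> 0.7"}])

t = Nx.tensor([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]], type: {:f, 64})

# Row-wise L2 norm: square every entry, sum over axis 1, then take the square root.
row_norms =
  t
  |> Nx.multiply(t)
  |> Nx.sum(axes: [1])
  |> Nx.sqrt()

IO.inspect(row_norms)
```

The result keeps the `f64` type and has shape `{3}`, one norm per row.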