elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Nx.slice/3 lengths parameter fails inside defn while #1476

Closed: santiago-imelio closed this issue 6 months ago

santiago-imelio commented 6 months ago

Hello, I'm working on a notebook that implements a 2D median filter using Nx. For each pixel of an image, I take a window of a given size centered on that pixel and compute its median, and the pixel value is replaced with that median.

For this I used two while loops to iterate over all pixels of the image, and Nx.slice/3 to extract the window. The lengths of the slice are dynamic: they depend on values that I carry in the while loop's state tuple.
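For reference, the filter described above can be sketched in plain Elixir without Nx, which makes the windowing explicit. This is an illustrative sketch, not the notebook's code: the module name `MedianFilterReference` is hypothetical, window dimensions are assumed to be odd, and pixels near the border are assumed to use a window shifted inward so that it always keeps its full size.

```elixir
defmodule MedianFilterReference do
  # Hypothetical plain-Elixir reference of a 2D median filter (no Nx).
  # Assumptions: odd window sizes; at the borders the window is shifted
  # inward so it always covers a full m x n block of the image.
  def median_filter_2d(image, {m, n}) do
    rows = length(image)
    cols = length(hd(image))
    # Nested tuples give constant-time indexed access via elem/2.
    grid = image |> Enum.map(&List.to_tuple/1) |> List.to_tuple()

    for r <- 0..(rows - 1) do
      for c <- 0..(cols - 1) do
        # Top-left corner of the window, centered on (r, c) when possible.
        r0 = clamp(r - div(m, 2), 0, rows - m)
        c0 = clamp(c - div(n, 2), 0, cols - n)

        window =
          for i <- r0..(r0 + m - 1), j <- c0..(c0 + n - 1) do
            grid |> elem(i) |> elem(j)
          end

        # Middle element of the sorted window (exact median for odd sizes).
        window |> Enum.sort() |> Enum.at(div(m * n, 2))
      end
    end
  end

  defp clamp(x, lo, hi), do: x |> max(lo) |> min(hi)
end

MedianFilterReference.median_filter_2d([[1, 2, 3, 4, 5]], {1, 3})
# → [[2, 2, 3, 4, 4]]
```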

However, passing these lengths to Nx.slice inside the while loop body raises the error below. Is there something I need to do before passing the parameters to Nx.slice? Or is Nx.slice simply not supported inside while loops?

Thanks in advance!

** (ArgumentError) length at axis 0 must be greater than or equal to 1, got: #Nx.Tensor<
  f32

  Nx.Defn.Expr
  parameter a:0                s64
  parameter b:6                f32
  parameter d:5                s64
  parameter g:3                s64
  c = add a, b                 f32
  e = subtract d, 1            s64
  f = greater c, e             u8
  h = add a, b                 f32
  i = subtract d, 1            s64
  j = subtract h, i            f32
  k = subtract g, j            f32
  l = as_type g                f32
  m = cond f -> k, true -> l   f32
>
    (nx 0.7.2) lib/nx/shape.ex:1257: Nx.Shape.do_slice/7
    (nx 0.7.2) lib/nx.ex:13591: Nx.slice/4
    #cell:rt2tpkxfftwlpodw:54: anonymous fn/2 in FiltersV2."__defn:median_filter_2d__"/2
    (nx 0.7.2) lib/nx/defn/expr.ex:519: Nx.Defn.Expr.while_vectorized/7
    #cell:rt2tpkxfftwlpodw:21: anonymous fn/2 in FiltersV2."__defn:median_filter_2d__"/2
    (nx 0.7.2) lib/nx/defn/expr.ex:519: Nx.Defn.Expr.while_vectorized/7
    #cell:rt2tpkxfftwlpodw:16: FiltersV2."__defn:median_filter_2d__"/2
    #cell:rm3xkh3wjjldz7mf:2: (file)
polvalente commented 6 months ago

You are trying to pass tensors as slice lengths, and that is not supported. The lengths determine the shape of the result, and Nx doesn't currently support dynamic shapes.
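A minimal sketch of the distinction, assuming Nx 0.7 or later installed via `Mix.install`; the module and function names (`SliceDemo`, `window_at`) are hypothetical. The start index may be a tensor whose value is only known at run time, while the length is a literal integer, so the output shape is fixed when the function is compiled.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule SliceDemo do
  import Nx.Defn

  # Hypothetical helper: `i` arrives as a scalar tensor (a dynamic value),
  # but the length 3 is a literal integer, so the result shape {3} is
  # known at compile time.
  defn window_at(t, i) do
    Nx.slice(t, [i], [3])
  end

  # Passing a tensor as the length instead, e.g. Nx.slice(t, [i], [len])
  # with `len` a defn argument, raises an ArgumentError like the one above.
end

SliceDemo.window_at(Nx.iota({10}), 4) |> Nx.to_flat_list()
# → [4, 5, 6]
```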

polvalente commented 6 months ago

Upon re-reading the issue, I believe "dynamic length" here means "passed as an argument". If that's the case, you can pass the lengths as literal integers through a keyword list of options:

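```elixir
defn f(a, opts \\ []) do
  opts = keyword!(opts, [:window_dims])

  # m and n are plain integers here, so they can be used as Nx.slice/3 lengths
  {m, n} = opts[:window_dims]
  # ...
end

# Called with literal integers, e.g. f(a, window_dims: {3, 3})
```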

When I first read the issue, I had interpreted "dynamic" as "depending on the tensor values", which is what prompted my first response.
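Putting both replies together, the median filter can be sketched with the window size passed as a keyword option of literal integers, while only the slice start indices are tensors. This is an illustrative sketch, not the original notebook's code: it assumes Nx 0.7 or later via `Mix.install`, odd window dimensions, and windows that are shifted inward at the borders (rather than truncated or padded). The module name and the single flattened loop (instead of two nested while loops) are illustrative choices.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule MedianFilterSketch do
  import Nx.Defn

  # Hypothetical module. `window_dims` must be a tuple of literal integers,
  # e.g. median_filter_2d(image, window_dims: {3, 3}); odd sizes assumed.
  defn median_filter_2d(image, opts \\ []) do
    opts = keyword!(opts, [:window_dims])
    {m, n} = opts[:window_dims]
    {rows, cols} = Nx.shape(image)

    image = Nx.as_type(image, :f32)
    result = Nx.broadcast(0.0, {rows, cols})

    # One loop over all rows * cols pixels instead of two nested loops.
    {_i, _image, result} =
      while {i = 0, image, result}, i < rows * cols do
        row = Nx.quotient(i, cols)
        col = Nx.remainder(i, cols)

        # Dynamic start indices (tensors), centered on the pixel and shifted
        # inward at the borders so the whole m x n window stays in bounds.
        row_start = Nx.clip(row - Nx.quotient(m, 2), 0, rows - m)
        col_start = Nx.clip(col - Nx.quotient(n, 2), 0, cols - n)

        # Static lengths: m and n are plain integers, so the shape is known.
        window = Nx.slice(image, [row_start, col_start], [m, n])
        median = Nx.median(window)

        result = Nx.put_slice(result, [row, col], Nx.reshape(median, {1, 1}))
        {i + 1, image, result}
      end

    result
  end
end

image = Nx.tensor([[1, 2, 3, 4, 5]])
MedianFilterSketch.median_filter_2d(image, window_dims: {1, 3}) |> Nx.to_list()
# → [[2.0, 2.0, 3.0, 4.0, 4.0]]
```

Keeping the image inside the while state tuple makes it explicitly available to the loop body; the integers `rows`, `cols`, `m` and `n` are compile-time constants.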