elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Feat tensor inspect pretty #1561

Closed fredguth closed 23 hours ago

fredguth commented 4 days ago

This PR is not meant to be merged; it is for discussion. With the :pretty option, inspecting an Nx tensor renders a summary in the style of the lovely_tensors Python library, e.g.:

Nx.Tensor<f32[2][128] n=256 x∈[0.006, 0.996] μ=0.517 σ=0.291>

I find it more useful than the current implementation. Another option is to enable it through custom_options (e.g. custom_options: [lovely: true]) instead.
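A summary line like the one above can be approximated with functions Nx already provides. The following is only a minimal sketch: the `TensorSummary` module and its `summarize/1` function are hypothetical names, not part of Nx or of this PR, and it assumes a real-valued (non-complex) tensor on a backend that can evaluate eagerly.

```elixir
Mix.install([:nx])

defmodule TensorSummary do
  # Hypothetical helper (not part of Nx): builds a one-line summary
  # in the style of lovely_tensors from type, shape, and basic statistics.
  def summarize(%Nx.Tensor{} = t) do
    {kind, bits} = Nx.type(t)

    # Render a shape such as {2, 128} as "[2][128]".
    dims = t |> Nx.shape() |> Tuple.to_list() |> Enum.map_join(&"[#{&1}]")

    min = t |> Nx.reduce_min() |> Nx.to_number() |> format_number()
    max = t |> Nx.reduce_max() |> Nx.to_number() |> format_number()
    mean = t |> Nx.mean() |> Nx.to_number() |> format_number()
    std = t |> Nx.standard_deviation() |> Nx.to_number() |> format_number()

    "Nx.Tensor<#{kind}#{bits}#{dims} n=#{Nx.size(t)} x∈[#{min}, #{max}] μ=#{mean} σ=#{std}>"
  end

  defp format_number(n) when is_integer(n), do: Integer.to_string(n)
  defp format_number(n) when is_float(n), do: :erlang.float_to_binary(n, decimals: 3)
  # Non-finite floats come back from Nx.to_number/1 as atoms such as :nan.
  defp format_number(n) when is_atom(n), do: Atom.to_string(n)
end

IO.puts(TensorSummary.summarize(Nx.tensor([1, 2, 3, 4], type: :s32)))
# prints: Nx.Tensor<s32[4] n=4 x∈[1, 4] μ=2.500 σ=1.118>
```

The sketch returns a plain string rather than hooking into `inspect`, so it leaves the existing tensor representation untouched.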

polvalente commented 4 days ago

I don't agree with :pretty making it so that no data from the tensor is printed at all.

fredguth commented 3 days ago

@polvalente Thanks for the review. The main idea is to convey more information about the tensor in less text. Humans tend not to grok raw data. The stats chosen just give an overall sense of the domain and values, and there are many cases where they are meaningless. But two lines of raw data are even less meaningful.

I also don't know if it's a good idea to tie this to :pretty if :pretty is being used for debugging purposes, and I did not even consider the Nx.Defn.Expr problem. The other option that I know of is using custom_options. Can I use custom_options without changing the Nx.tensor code?

I am new to Elixir; my experience is with PyTorch, where I always "monkeypatch" the string representation of tensors using lovely_tensors. I find these stats very useful. If there is a way of changing the string representation without messing with Nx.tensor, I would prefer that.

polvalente commented 23 hours ago

Here's a suggestion that can move things forward without any changes to Nx itself:

# consolidate_protocols: false allows defining a new Inspect implementation after install
Mix.install([:nx], consolidate_protocols: false)

defmodule PrettyTensor do
  defstruct [:tensor]

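  # Called by the tensor's Inspect implementation to render the data portion,
  # because the :data field holds a %PrettyTensor{} struct (see defimpl below).
  def inspect(%Nx.Tensor{data: %__MODULE__{tensor: tensor}}, _inspect_opts) do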
    import Inspect.Algebra
    max = Nx.reduce_max(tensor) |> Nx.to_number() |> format_number()
    min = Nx.reduce_min(tensor) |> Nx.to_number() |> format_number()
    mean = Nx.mean(tensor) |> Nx.to_number() |> format_number()
    std = Nx.standard_deviation(tensor) |> Nx.to_number() |> format_number()

    concat([
      "x in [" <> min <> ", " <> max <> "]",
      line(),
      "mean(x): " <> mean,
      line(),
      "standard_deviation(x): " <> std
    ])
  end

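  defimpl Inspect do
    # The backend is the struct stored in the tensor's :data field (e.g. Nx.BinaryBackend).
    # Expressions built inside defn use Nx.Defn.Expr and keep the default output.
    def inspect(%PrettyTensor{tensor: %Nx.Tensor{data: %backend{}} = tensor}, inspect_opts) do
      cond do
        inspect_opts.pretty and backend != Nx.Defn.Expr ->
          Inspect.Nx.Tensor.inspect(%{tensor | data: %PrettyTensor{tensor: tensor}}, inspect_opts)

        true ->
          Inspect.Nx.Tensor.inspect(tensor, inspect_opts)
      end
    end
  end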

  defp format_number(n) when is_integer(n), do: Integer.to_string(n)
  defp format_number(n) when is_float(n), do: :erlang.float_to_binary(n, decimals: 3)
end

t = Nx.tensor([1, 2, 3, 4])

IO.inspect(%PrettyTensor{tensor: t})

IO.inspect(%PrettyTensor{tensor: t}, pretty: true)

which prints:

#Nx.Tensor<
  s32[4]
  [1, 2, 3, 4]
>
#Nx.Tensor<
  s32[4]
  x in [1, 4]
  mean(x): 2.500
  standard_deviation(x): 1.118
>

Instead of making changes to the Tensor inspect implementations, you can leverage them from a separate struct like above. Note that we're overriding :data with our custom struct so that we keep the tensor's inspect container (the #Nx.Tensor<...> wrapper with type and shape) while we provide custom content for the data portion.

This could end up including the original tensor data as well if desired.
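One way this might look is sketched below. `PrettyTensorWithData` is a hypothetical variant of the struct above, not something that exists in Nx. It assumes the wrapped tensor's backend implements the `inspect/2` callback of the `Nx.Backend` behaviour (as `Nx.BinaryBackend` does) and appends that backend's rendering of the data after the statistics.

```elixir
Mix.install([:nx], consolidate_protocols: false)

defmodule PrettyTensorWithData do
  # Hypothetical variant of PrettyTensor: prints the statistics and the raw data.
  defstruct [:tensor]

  # Called for the data portion once :data holds %PrettyTensorWithData{}.
  def inspect(%Nx.Tensor{data: %__MODULE__{tensor: tensor}}, inspect_opts) do
    import Inspect.Algebra

    # Assumption: the backend module implements the Nx.Backend inspect/2 callback.
    %Nx.Tensor{data: %backend{}} = tensor

    concat([
      "x in [#{stat(tensor, &Nx.reduce_min/1)}, #{stat(tensor, &Nx.reduce_max/1)}]",
      line(),
      "mean(x): #{stat(tensor, &Nx.mean/1)}",
      line(),
      "standard_deviation(x): #{stat(tensor, &Nx.standard_deviation/1)}",
      line(),
      # Reuse the backend's own rendering of the original data.
      backend.inspect(tensor, inspect_opts)
    ])
  end

  defimpl Inspect do
    def inspect(
          %PrettyTensorWithData{tensor: %Nx.Tensor{data: %backend{}} = tensor},
          inspect_opts
        ) do
      if inspect_opts.pretty and backend != Nx.Defn.Expr do
        wrapped = %{tensor | data: %PrettyTensorWithData{tensor: tensor}}
        Inspect.Nx.Tensor.inspect(wrapped, inspect_opts)
      else
        Inspect.Nx.Tensor.inspect(tensor, inspect_opts)
      end
    end
  end

  # Reduce the tensor to a scalar with `fun` and format it for display.
  defp stat(tensor, fun), do: tensor |> fun.() |> Nx.to_number() |> format_number()

  defp format_number(n) when is_integer(n), do: Integer.to_string(n)
  defp format_number(n) when is_float(n), do: :erlang.float_to_binary(n, decimals: 3)
end

IO.inspect(%PrettyTensorWithData{tensor: Nx.tensor([1, 2, 3, 4])}, pretty: true)
```

Because the data doc comes from the backend, the usual inspect options such as `:limit` should still apply to the raw values.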

fredguth commented 16 hours ago

Thanks for the review. I will proceed in the way suggested.