elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0

Examples not working #199

Closed: PhillippOhlandt closed this issue 2 years ago

PhillippOhlandt commented 2 years ago

Hello,

I wanted to try out some of the examples in the examples folder, since the code in the MNIST Livebook is outdated.

However, there seem to be some issues with the examples. I haven't tried all of them, but I suspect they are all affected. I changed the dependencies of each example I tested:

Here are some logs:

$ elixir examples/basics/multi_input_example.exs
** (Axon.CompilerError) error while building prediction for sigmoid:

** (Axon.CompilerError) error while building prediction for dense:

** (Axon.CompilerError) error while building prediction for tanh:

** (Axon.CompilerError) error while building prediction for dense:

** (ArgumentError) dot/zip expects shapes to be compatible, dimension 1 of left-side (2) does not equal dimension 0 of right-side (8)

    (nx 0.1.0) lib/nx/shape.ex:467: Nx.Shape.validate_zip_reduce_axes!/4
    (nx 0.1.0) lib/nx/shape.ex:442: Nx.Shape.zip_reduce/6
    (nx 0.1.0) lib/nx/shape.ex:1496: Nx.Shape.dot/8
    (nx 0.1.0) lib/nx.ex:7091: Nx.dot/6
    (axon 0.1.0-dev) lib/axon/layers.ex:107: Axon.Layers."__defn:dense__"/3
    (axon 0.1.0-dev) lib/axon/compiler.ex:629: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:382: Axon.Compiler.to_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:533: Axon.Compiler.recur_predict_fun/6

$ elixir examples/basics/multi_output_example.exs
** (Axon.CompilerError) error while building prediction for dense:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for dense:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for dense:

** (ArgumentError) dot/zip expects shapes to be compatible, dimension 1 of left-side (1) does not equal dimension 0 of right-side (64)

    (nx 0.1.0) lib/nx/shape.ex:467: Nx.Shape.validate_zip_reduce_axes!/4
    (nx 0.1.0) lib/nx/shape.ex:442: Nx.Shape.zip_reduce/6
    (nx 0.1.0) lib/nx/shape.ex:1496: Nx.Shape.dot/8
    (nx 0.1.0) lib/nx.ex:7091: Nx.dot/6
    (axon 0.1.0-dev) lib/axon/layers.ex:107: Axon.Layers."__defn:dense__"/3
    (axon 0.1.0-dev) lib/axon/compiler.ex:629: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:382: Axon.Compiler.to_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:533: Axon.Compiler.recur_predict_fun/6
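
Both dense failures come from the same shape rule. The traces show Axon.Layers.dense calling Nx.dot, which contracts the last axis of the input with axis 0 of the kernel, so those two sizes must match. The sketch below restates that check in plain Python. It is not Nx's implementation, and the batch size (4) and output widths are hypothetical placeholders; only the 2-vs-8 and 1-vs-64 pairs come from the logs above.

```python
def dense_shape_error(input_shape, kernel_shape):
    """Return the mismatch message for input . kernel, or None if it is valid.

    Plain-Python restatement of the rule in the logs (not Nx's code):
    the last axis of the input must equal axis 0 of the dense kernel.
    """
    last_axis = len(input_shape) - 1
    left, right = input_shape[last_axis], kernel_shape[0]
    if left != right:
        return (f"dimension {last_axis} of left-side ({left}) "
                f"does not equal dimension 0 of right-side ({right})")
    return None


def dense_output_shape(input_shape, kernel_shape):
    """Leading (batch) axes of the input followed by the kernel's output axis."""
    error = dense_shape_error(input_shape, kernel_shape)
    if error:
        raise ValueError(error)
    return tuple(input_shape[:-1]) + tuple(kernel_shape[1:])


# Hypothetical batch size (4) and output widths; only 2 vs 8 and 1 vs 64
# are taken from the two logs above.
print(dense_shape_error((4, 2), (8, 1)))
print(dense_shape_error((4, 1), (64, 1)))
print(dense_output_shape((4, 8), (8, 1)))
```

Reproducing the messages only confirms that each failing dense layer received a tensor narrower than the one it was built for; it does not say why the wrong tensor arrived there.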

For the horses_or_humans example, I kept the GitHub dependency on Nx, because otherwise compilation complained about undefined functions:

==> axon
Compiling 20 files (.ex)
warning: Nx.slice_along_axis/4 is undefined or private. Did you mean:

      * slice/3
      * slice_axis/4
      * slice_axis/5
      * take_along_axis/2
      * take_along_axis/3

  lib/axon.ex:1654: Axon.split/3

warning: Nx.slice_along_axis/4 is undefined or private. Did you mean:

      * slice/3
      * slice_axis/4
      * slice_axis/5
      * take_along_axis/2
      * take_along_axis/3

Invalid call found at 3 locations:
  lib/axon/recurrent.ex:150: Axon.Recurrent."__defn:split_gates__"/1
  lib/axon/recurrent.ex:174: Axon.Recurrent."__defn:dynamic_unroll__"/6
  lib/axon/recurrent.ex:196: Axon.Recurrent."__defn:static_unroll__"/6

Here is the log output with Nx from GitHub and EXLA removed from the example entirely (and yes, I downloaded the dataset):

$ elixir examples/vision/horses_or_humans.exs
--------------------------------------------------------------------------------------
                                        Model
======================================================================================
 Layer                                               Shape                 Parameters
======================================================================================
 input_0 ( input )                                   {nil, 4, 300, 300}    0
 conv_0 ( conv[ "input_0" ] )                        {nil, 16, 298, 298}   592
 relu_0 ( relu[ "conv_0" ] )                         {nil, 16, 298, 298}   0
 max_pool_0 ( max_pool[ "relu_0" ] )                 {nil, 16, 149, 149}   0
 conv_1 ( conv[ "max_pool_0" ] )                     {nil, 32, 147, 147}   4640
 relu_1 ( relu[ "conv_1" ] )                         {nil, 32, 147, 147}   0
 spatial_dropout_0 ( spatial_dropout[ "relu_1" ] )   {nil, 32, 147, 147}   0
 max_pool_1 ( max_pool[ "spatial_dropout_0" ] )      {nil, 32, 73, 73}     0
 conv_2 ( conv[ "max_pool_1" ] )                     {nil, 64, 71, 71}     18496
 relu_2 ( relu[ "conv_2" ] )                         {nil, 64, 71, 71}     0
 spatial_dropout_1 ( spatial_dropout[ "relu_2" ] )   {nil, 64, 71, 71}     0
 max_pool_2 ( max_pool[ "spatial_dropout_1" ] )      {nil, 64, 35, 35}     0
 conv_3 ( conv[ "max_pool_2" ] )                     {nil, 64, 33, 33}     36928
 relu_3 ( relu[ "conv_3" ] )                         {nil, 64, 33, 33}     0
 max_pool_3 ( max_pool[ "relu_3" ] )                 {nil, 64, 16, 16}     0
 conv_4 ( conv[ "max_pool_3" ] )                     {nil, 64, 14, 14}     36928
 relu_4 ( relu[ "conv_4" ] )                         {nil, 64, 14, 14}     0
 max_pool_4 ( max_pool[ "relu_4" ] )                 {nil, 64, 7, 7}       0
 flatten_0 ( flatten[ "max_pool_4" ] )               {nil, 3136}           0
 dropout_0 ( dropout[ "flatten_0" ] )                {nil, 3136}           0
 dense_0 ( dense[ "dropout_0" ] )                    {nil, 512}            1606144
 relu_5 ( relu[ "dense_0" ] )                        {nil, 512}            0
 dense_1 ( dense[ "relu_5" ] )                       {nil, 2}              1026
 softmax_0 ( softmax[ "dense_1" ] )                  {nil, 2}              0
--------------------------------------------------------------------------------------
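
The shapes and parameter counts in this summary are internally consistent, which is worth confirming before suspecting the model definition. The sketch below replays the arithmetic under assumptions inferred from the table, not read from the example's source: 3x3 kernels with valid padding and stride 1, and 2x2 max pooling with stride 2.

```python
def conv_out(size, kernel=3):
    # Valid padding, stride 1: each spatial side shrinks by kernel - 1.
    return size - kernel + 1


def pool_out(size, window=2):
    # Stride equals the window; a leftover row/column is dropped.
    return size // window


def conv_params(in_ch, out_ch, kernel=3):
    # One kernel x kernel filter per (output, input) channel pair,
    # plus one bias per output channel.
    return out_ch * in_ch * kernel * kernel + out_ch


def dense_params(in_features, units):
    # Weight matrix plus one bias per unit.
    return in_features * units + units


# Replay the five conv -> relu -> (dropout) -> max_pool blocks; relu and
# dropout layers keep the shape and add no parameters, so they are skipped.
size, in_ch = 300, 4
conv_rows = []
for out_ch in (16, 32, 64, 64, 64):
    size = conv_out(size)
    conv_rows.append((out_ch, size, conv_params(in_ch, out_ch)))
    size = pool_out(size)
    in_ch = out_ch

flat = in_ch * size * size
for out_ch, side, params in conv_rows:
    print(f"conv -> {{nil, {out_ch}, {side}, {side}}}  params={params}")
print(f"flatten -> {{nil, {flat}}}")
print(f"dense_0 params={dense_params(flat, 512)}, dense_1 params={dense_params(512, 2)}")
```

Every row matches the table (592, 4640, 18496, 36928, 36928, then 3136 features, 1606144 and 1026 dense parameters), which suggests the layer definitions are sound and the failure below arises while the compiled model is being built and run.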

Training model without gradient centralization

** (Axon.CompilerError) error while building prediction for softmax:

** (Axon.CompilerError) error while building prediction for dense:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for dense:

** (Axon.CompilerError) error while building prediction for dropout:

** (Axon.CompilerError) error while building prediction for flatten:

** (Axon.CompilerError) error while building prediction for max_pool:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for conv:

** (Axon.CompilerError) error while building prediction for max_pool:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for conv:

** (Axon.CompilerError) error while building prediction for max_pool:

** (Axon.CompilerError) error while building prediction for spatial_dropout:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for conv:

** (Axon.CompilerError) error while building prediction for max_pool:

** (Axon.CompilerError) error while building prediction for spatial_dropout:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for conv:

** (Axon.CompilerError) error while building prediction for max_pool:

** (Axon.CompilerError) error while building prediction for relu:

** (Axon.CompilerError) error while building prediction for conv:

** (ArgumentError) size of input channels divided by feature groups must match size of kernel channels, got 4 / 1 != 64 for shapes {32, 4, 300, 300} and {64, 64, 3, 3}

    (nx 0.1.0) lib/nx/shape.ex:713: Nx.Shape.validate_conv_groups!/4
    (nx 0.1.0) lib/nx/shape.ex:586: Nx.Shape.conv/13
    (nx 0.1.0) lib/nx.ex:7844: Nx.conv/3
    (axon 0.1.0-dev) lib/axon/layers.ex:310: Axon.Layers."__defn:conv__"/4
    (axon 0.1.0-dev) lib/axon/compiler.ex:921: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:382: Axon.Compiler.to_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:533: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:382: Axon.Compiler.to_predict_fun/6
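
The conv failure is the same kind of check applied to channels: the input's channel axis, divided by the number of feature groups, must equal the kernel's input-channel axis (axis 1 of a kernel in {out_channels, in_channels, kh, kw} layout). The sketch below compares the kernel shape named in the error against kernel shapes reconstructed from the summary table, assuming 3x3 kernels. The result suggests, though the thread does not confirm it, that the raw {32, 4, 300, 300} input reached conv_3 or conv_4 rather than conv_0.

```python
# Kernel shapes reconstructed from the summary table, in
# {out_channels, in_channels, kh, kw} layout; 3x3 is an assumption.
KERNELS = {
    "conv_0": (16, 4, 3, 3),
    "conv_1": (32, 16, 3, 3),
    "conv_2": (64, 32, 3, 3),
    "conv_3": (64, 64, 3, 3),
    "conv_4": (64, 64, 3, 3),
}


def conv_channel_error(input_shape, kernel_shape, feature_groups=1):
    """Return a mismatch message, or None if the channel axes agree.

    Channels-first layout assumed: input {batch, channels, h, w}.
    """
    in_ch, kernel_in = input_shape[1], kernel_shape[1]
    if in_ch % feature_groups != 0 or in_ch // feature_groups != kernel_in:
        return f"got {in_ch} / {feature_groups} != {kernel_in}"
    return None


raw_input = (32, 4, 300, 300)  # input shape from the error message
print(conv_channel_error(raw_input, (64, 64, 3, 3)))
# Which layers could legally receive the raw 4-channel input?
print([name for name, k in KERNELS.items() if conv_channel_error(raw_input, k) is None])
# Which layers own the kernel shape named in the error?
print([name for name, k in KERNELS.items() if k == (64, 64, 3, 3)])
```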

seanmor5 commented 2 years ago

I have fixed this issue on the main branch. You can safely ignore the warnings EXLA spits out in WSL. They are all warnings from upstream. Please reopen if you observe any more issues!