SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Neural Graph Differential Equations example not working #684

Closed: gcrth closed this issue 2 years ago

gcrth commented 2 years ago

The example at https://diffeqflux.sciml.ai/dev/examples/neural_gde/ does not run. Running the last line of code raises the error shown below.

julia> for i = 1:epochs
           Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
       end
┌ Warning: Expected 7344092 params, got 272
└ @ Flux ~/.julia/packages/Flux/BPPNj/src/utils.jl:647
ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")

All packages I use are up to date.

I could not find a solution, so any suggestions for fixing this would be welcome.

There are also a few small problems earlier in the example, but they are easy to fix:

  1. using LightGraphs: adjacency_matrix should be using Graphs: adjacency_matrix.
  2. The data link in the example may be broken, but the data can be found at https://github.com/FluxML/GeometricFlux.jl/tree/master/data.

Any help is appreciated, thank you.

ChrisRackauckas commented 2 years ago

@avik-pal could you take a sec to look into this one?

avik-pal commented 2 years ago

@ChrisRackauckas I am not exactly sure what is going on here.

If I pass features outside gradient(), the sizes line up at (2, 4), but inside the gradient block sol.u[1] is flattened to a vector of length 8 (https://github.com/SciML/DiffEqFlux.jl/blob/cff4c35555c9d7f24803a43ae28bd56b445fceaa/test/neural_gde.jl#L30).
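The shape problem described here can be illustrated with a minimal sketch in plain Base Julia. It does not use the DiffEqFlux or GeometricFlux APIs and assumes only the sizes quoted above (2 features for each of 4 nodes); all variable names are hypothetical. It shows that broadcasting a flattened length-8 vector into a 2×4 destination raises a DimensionMismatch, the same exception type as in the report, and that reshaping back to the original size restores the expected layout. Whether a reshape is the right fix inside the actual example is an assumption, not something confirmed in this thread.

```julia
# Hypothetical node-feature state: 2 features for each of 4 nodes,
# matching the (2, 4) size seen outside gradient().
u0 = reshape(collect(1.0:8.0), 2, 4)

# Inside the gradient block, the state reportedly arrives flattened
# to a length-8 vector; vec() reproduces that layout here.
u_flat = vec(u0)

# Broadcasting the flat vector into a 2×4 destination fails because the
# first dimensions (2 vs 8) differ and neither is 1.
dest = similar(u0)
mismatch = try
    dest .= u_flat
    false
catch err
    err isa DimensionMismatch
end

# Reshaping back to the original (features, nodes) size restores the
# matrix a graph layer would expect (hypothetical fix, see lead-in).
u_restored = reshape(u_flat, size(u0))
```

Running it should leave mismatch set to true and u_restored equal to u0.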

ChrisRackauckas commented 2 years ago

@yuehhua was there a change to sizing?

yuehhua commented 2 years ago

Sorry, I will take care of this issue. There have been some changes to the GCNConv layer.

ChrisRackauckas commented 2 years ago

For now I'm disabling the graph conv test, since I couldn't quickly figure it out: https://github.com/SciML/DiffEqFlux.jl/pull/693